diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 858e10ef84c..6ec4859f0e9 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -1,6 +1,8 @@ import datetime import decimal import uuid +from functools import lru_cache +from itertools import chain from django.conf import settings from django.core.exceptions import FieldError @@ -11,6 +13,7 @@ from django.db.models.expressions import Col from django.utils import timezone from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.duration import duration_microseconds +from django.utils.functional import cached_property class DatabaseOperations(BaseDatabaseOperations): @@ -158,28 +161,36 @@ class DatabaseOperations(BaseDatabaseOperations): def no_limit_value(self): return -1 + def __references_graph(self, table_name): + query = """ + WITH tables AS ( + SELECT %s name + UNION + SELECT sqlite_master.name + FROM sqlite_master + JOIN tables ON (sql REGEXP %s || tables.name || %s) + ) SELECT name FROM tables; + """ + params = ( + table_name, + r'(?i)\s+references\s+("|\')?', + r'("|\')?\s*\(', + ) + with self.connection.cursor() as cursor: + results = cursor.execute(query, params) + return [row[0] for row in results.fetchall()] + + @cached_property + def _references_graph(self): + # 512 is large enough to fit the ~330 tables (as of this writing) in + # Django's test suite. + return lru_cache(maxsize=512)(self.__references_graph) + def sql_flush(self, style, tables, sequences, allow_cascade=False): if tables and allow_cascade: # Simulate TRUNCATE CASCADE by recursively collecting the tables # referencing the tables to be flushed. - query = """ - WITH tables AS ( - %s - UNION - SELECT sqlite_master.name - FROM sqlite_master - JOIN tables ON ( - sql REGEXP %%s || tables.name || %%s - ) - ) SELECT name FROM tables; - """ % ' UNION '.join("SELECT '%s' name" % table for table in tables) - params = ( - r'(?i)\s+references\s+("|\')?', - r'("|\')?\s*\(', - ) - with self.connection.cursor() as cursor: - results = cursor.execute(query, params) - tables = [row[0] for row in results.fetchall()] + tables = set(chain.from_iterable(self._references_graph(table) for table in tables)) sql = ['%s %s %s;' % ( style.SQL_KEYWORD('DELETE'), style.SQL_KEYWORD('FROM'),