This commit is contained in:
parent
7a6dbbb655
commit
2b2ae4eeb7
|
@ -1,6 +1,8 @@
|
|||
import datetime
|
||||
import decimal
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
|
@ -11,6 +13,7 @@ from django.db.models.expressions import Col
|
|||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.duration import duration_microseconds
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
|
@ -158,28 +161,36 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
def no_limit_value(self):
|
||||
return -1
|
||||
|
||||
def __references_graph(self, table_name):
|
||||
query = """
|
||||
WITH tables AS (
|
||||
SELECT %s name
|
||||
UNION
|
||||
SELECT sqlite_master.name
|
||||
FROM sqlite_master
|
||||
JOIN tables ON (sql REGEXP %s || tables.name || %s)
|
||||
) SELECT name FROM tables;
|
||||
"""
|
||||
params = (
|
||||
table_name,
|
||||
r'(?i)\s+references\s+("|\')?',
|
||||
r'("|\')?\s*\(',
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
results = cursor.execute(query, params)
|
||||
return [row[0] for row in results.fetchall()]
|
||||
|
||||
@cached_property
|
||||
def _references_graph(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__references_graph)
|
||||
|
||||
def sql_flush(self, style, tables, sequences, allow_cascade=False):
|
||||
if tables and allow_cascade:
|
||||
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
||||
# referencing the tables to be flushed.
|
||||
query = """
|
||||
WITH tables AS (
|
||||
%s
|
||||
UNION
|
||||
SELECT sqlite_master.name
|
||||
FROM sqlite_master
|
||||
JOIN tables ON (
|
||||
sql REGEXP %%s || tables.name || %%s
|
||||
)
|
||||
) SELECT name FROM tables;
|
||||
""" % ' UNION '.join("SELECT '%s' name" % table for table in tables)
|
||||
params = (
|
||||
r'(?i)\s+references\s+("|\')?',
|
||||
r'("|\')?\s*\(',
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
results = cursor.execute(query, params)
|
||||
tables = [row[0] for row in results.fetchall()]
|
||||
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
|
||||
sql = ['%s %s %s;' % (
|
||||
style.SQL_KEYWORD('DELETE'),
|
||||
style.SQL_KEYWORD('FROM'),
|
||||
|
|
Loading…
Reference in New Issue