This commit is contained in:
parent
7a6dbbb655
commit
2b2ae4eeb7
|
@ -1,6 +1,8 @@
|
||||||
import datetime
|
import datetime
|
||||||
import decimal
|
import decimal
|
||||||
import uuid
|
import uuid
|
||||||
|
from functools import lru_cache
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
|
@ -11,6 +13,7 @@ from django.db.models.expressions import Col
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||||
from django.utils.duration import duration_microseconds
|
from django.utils.duration import duration_microseconds
|
||||||
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
class DatabaseOperations(BaseDatabaseOperations):
|
class DatabaseOperations(BaseDatabaseOperations):
|
||||||
|
@ -158,28 +161,36 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
def no_limit_value(self):
|
def no_limit_value(self):
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
def __references_graph(self, table_name):
|
||||||
|
query = """
|
||||||
|
WITH tables AS (
|
||||||
|
SELECT %s name
|
||||||
|
UNION
|
||||||
|
SELECT sqlite_master.name
|
||||||
|
FROM sqlite_master
|
||||||
|
JOIN tables ON (sql REGEXP %s || tables.name || %s)
|
||||||
|
) SELECT name FROM tables;
|
||||||
|
"""
|
||||||
|
params = (
|
||||||
|
table_name,
|
||||||
|
r'(?i)\s+references\s+("|\')?',
|
||||||
|
r'("|\')?\s*\(',
|
||||||
|
)
|
||||||
|
with self.connection.cursor() as cursor:
|
||||||
|
results = cursor.execute(query, params)
|
||||||
|
return [row[0] for row in results.fetchall()]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _references_graph(self):
|
||||||
|
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||||
|
# Django's test suite.
|
||||||
|
return lru_cache(maxsize=512)(self.__references_graph)
|
||||||
|
|
||||||
def sql_flush(self, style, tables, sequences, allow_cascade=False):
|
def sql_flush(self, style, tables, sequences, allow_cascade=False):
|
||||||
if tables and allow_cascade:
|
if tables and allow_cascade:
|
||||||
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
||||||
# referencing the tables to be flushed.
|
# referencing the tables to be flushed.
|
||||||
query = """
|
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
|
||||||
WITH tables AS (
|
|
||||||
%s
|
|
||||||
UNION
|
|
||||||
SELECT sqlite_master.name
|
|
||||||
FROM sqlite_master
|
|
||||||
JOIN tables ON (
|
|
||||||
sql REGEXP %%s || tables.name || %%s
|
|
||||||
)
|
|
||||||
) SELECT name FROM tables;
|
|
||||||
""" % ' UNION '.join("SELECT '%s' name" % table for table in tables)
|
|
||||||
params = (
|
|
||||||
r'(?i)\s+references\s+("|\')?',
|
|
||||||
r'("|\')?\s*\(',
|
|
||||||
)
|
|
||||||
with self.connection.cursor() as cursor:
|
|
||||||
results = cursor.execute(query, params)
|
|
||||||
tables = [row[0] for row in results.fetchall()]
|
|
||||||
sql = ['%s %s %s;' % (
|
sql = ['%s %s %s;' % (
|
||||||
style.SQL_KEYWORD('DELETE'),
|
style.SQL_KEYWORD('DELETE'),
|
||||||
style.SQL_KEYWORD('FROM'),
|
style.SQL_KEYWORD('FROM'),
|
||||||
|
|
Loading…
Reference in New Issue