Fixed #33340 -- Fixed unquoted column names in queries used by DatabaseCache.

This commit is contained in:
Arsa 2021-12-06 16:55:29 +01:00 committed by Mariusz Felisiak
parent eba9a9b7f7
commit 17df72114e
5 changed files with 36 additions and 7 deletions

View File

@ -97,6 +97,7 @@ answer newbie questions, and generally made Django that much better:
arien <regexbot@gmail.com> arien <regexbot@gmail.com>
Armin Ronacher Armin Ronacher
Aron Podrigal <aronp@guaranteedplus.com> Aron Podrigal <aronp@guaranteedplus.com>
Arsalan Ghassemi <arsalan.ghassemi@gmail.com>
Artem Gnilov <boobsd@gmail.com> Artem Gnilov <boobsd@gmail.com>
Arthur <avandorp@gmail.com> Arthur <avandorp@gmail.com>
Arthur Jovart <arthur@jovart.com> Arthur Jovart <arthur@jovart.com>

View File

@ -228,10 +228,11 @@ class DatabaseCache(BaseDatabaseCache):
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute( cursor.execute(
'SELECT %s FROM %s WHERE %s = %%s and expires > %%s' % ( 'SELECT %s FROM %s WHERE %s = %%s and %s > %%s' % (
quote_name('cache_key'), quote_name('cache_key'),
quote_name(self._table), quote_name(self._table),
quote_name('cache_key'), quote_name('cache_key'),
quote_name('expires'),
), ),
[key, connection.ops.adapt_datetimefield_value(now)] [key, connection.ops.adapt_datetimefield_value(now)]
) )
@ -243,8 +244,10 @@ class DatabaseCache(BaseDatabaseCache):
else: else:
connection = connections[db] connection = connections[db]
table = connection.ops.quote_name(self._table) table = connection.ops.quote_name(self._table)
cursor.execute("DELETE FROM %s WHERE expires < %%s" % table, cursor.execute('DELETE FROM %s WHERE %s < %%s' % (
[connection.ops.adapt_datetimefield_value(now)]) table,
connection.ops.quote_name('expires'),
), [connection.ops.adapt_datetimefield_value(now)])
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
remaining_num = num - deleted_count remaining_num = num - deleted_count
if remaining_num > self._max_entries: if remaining_num > self._max_entries:
@ -255,7 +258,10 @@ class DatabaseCache(BaseDatabaseCache):
last_cache_key = cursor.fetchone() last_cache_key = cursor.fetchone()
if last_cache_key: if last_cache_key:
cursor.execute( cursor.execute(
'DELETE FROM %s WHERE cache_key < %%s' % table, 'DELETE FROM %s WHERE %s < %%s' % (
table,
connection.ops.quote_name('cache_key'),
),
[last_cache_key[0]], [last_cache_key[0]],
) )

View File

@ -88,7 +88,8 @@ class BaseDatabaseOperations:
This is used by the 'db' cache backend to determine where to start This is used by the 'db' cache backend to determine where to start
culling. culling.
""" """
return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s" cache_key = self.quote_name('cache_key')
return f'SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s'
def unification_cast_sql(self, output_field): def unification_cast_sql(self, output_field):
""" """

View File

@ -72,7 +72,12 @@ END;
} }
def cache_key_culling_sql(self): def cache_key_culling_sql(self):
return 'SELECT cache_key FROM %s ORDER BY cache_key OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY' cache_key = self.quote_name('cache_key')
return (
f'SELECT {cache_key} '
f'FROM %s '
f'ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY'
)
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, field_name):
if lookup_type == 'week_day': if lookup_type == 'week_day':

18
tests/cache/tests.py vendored
View File

@ -1113,7 +1113,7 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
cache.delete_many(['a', 'b', 'c']) cache.delete_many(['a', 'b', 'c'])
def test_cull_count_queries(self): def test_cull_queries(self):
old_max_entries = cache._max_entries old_max_entries = cache._max_entries
# Force _cull to delete on first cached record. # Force _cull to delete on first cached record.
cache._max_entries = -1 cache._max_entries = -1
@ -1124,6 +1124,13 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
cache._max_entries = old_max_entries cache._max_entries = old_max_entries
num_count_queries = sum('COUNT' in query['sql'] for query in captured_queries) num_count_queries = sum('COUNT' in query['sql'] for query in captured_queries)
self.assertEqual(num_count_queries, 1) self.assertEqual(num_count_queries, 1)
# Column names are quoted.
for query in captured_queries:
sql = query['sql']
if 'expires' in sql:
self.assertIn(connection.ops.quote_name('expires'), sql)
if 'cache_key' in sql:
self.assertIn(connection.ops.quote_name('cache_key'), sql)
def test_delete_cursor_rowcount(self): def test_delete_cursor_rowcount(self):
""" """
@ -1180,6 +1187,15 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
) )
self.assertEqual(out.getvalue(), "Cache table 'test cache table' created.\n") self.assertEqual(out.getvalue(), "Cache table 'test cache table' created.\n")
def test_has_key_query_columns_quoted(self):
with CaptureQueriesContext(connection) as captured_queries:
cache.has_key('key')
self.assertEqual(len(captured_queries), 1)
sql = captured_queries[0]['sql']
# Column names are quoted.
self.assertIn(connection.ops.quote_name('expires'), sql)
self.assertIn(connection.ops.quote_name('cache_key'), sql)
@override_settings(USE_TZ=True) @override_settings(USE_TZ=True)
class DBCacheWithTimeZoneTests(DBCacheTests): class DBCacheWithTimeZoneTests(DBCacheTests):