diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index be000248498..eb0ba18cdad 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -364,3 +364,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): @cached_property def mysql_is_mariadb(self): return 'mariadb' in self.mysql_server_info.lower() + + @cached_property + def sql_mode(self): + with self.cursor() as cursor: + cursor.execute('SELECT @@sql_mode') + sql_mode = cursor.fetchone() + return set(sql_mode[0].split(',') if sql_mode else ()) diff --git a/django/db/backends/mysql/validation.py b/django/db/backends/mysql/validation.py index 9def8e57f8b..ee1c360e351 100644 --- a/django/db/backends/mysql/validation.py +++ b/django/db/backends/mysql/validation.py @@ -10,11 +10,7 @@ class DatabaseValidation(BaseDatabaseValidation): return issues def _check_sql_mode(self, **kwargs): - with self.connection.cursor() as cursor: - cursor.execute("SELECT @@sql_mode") - sql_mode = cursor.fetchone() - modes = set(sql_mode[0].split(',') if sql_mode else ()) - if not (modes & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}): + if not (self.connection.sql_mode & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}): return [checks.Warning( "MySQL Strict Mode is not set for database connection '%s'" % self.connection.alias, hint="MySQL's Strict Mode fixes many data integrity problems in MySQL, " diff --git a/tests/check_framework/test_database.py b/tests/check_framework/test_database.py index bf291b24a1e..6e6b4e34683 100644 --- a/tests/check_framework/test_database.py +++ b/tests/check_framework/test_database.py @@ -2,7 +2,7 @@ import unittest from unittest import mock from django.core.checks.database import check_database_backends -from django.db import connection +from django.db import connection, connections from django.test import TestCase @@ -18,6 +18,12 @@ class DatabaseCheckTests(TestCase): @unittest.skipUnless(connection.vendor == 'mysql', 'Test only for MySQL') def test_mysql_strict_mode(self): + def _clean_sql_mode(): + for alias in self.databases: + if hasattr(connections[alias], 'sql_mode'): + del connections[alias].sql_mode + + _clean_sql_mode() good_sql_modes = [ 'STRICT_TRANS_TABLES,STRICT_ALL_TABLES', 'STRICT_TRANS_TABLES', @@ -29,6 +35,7 @@ class DatabaseCheckTests(TestCase): return_value=(response,) ): self.assertEqual(check_database_backends(databases=self.databases), []) + _clean_sql_mode() bad_sql_modes = ['', 'WHATEVER'] for response in bad_sql_modes: @@ -40,3 +47,4 @@ class DatabaseCheckTests(TestCase): result = check_database_backends(databases=self.databases) self.assertEqual(len(result), 2) self.assertEqual([r.id for r in result], ['mysql.W002', 'mysql.W002']) + _clean_sql_mode()