From 9ac3ef59f9538cfb520e3607af743532434d1755 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Mon, 27 Dec 2021 19:04:59 +0100 Subject: [PATCH] Fixed #33379 -- Added minimum database version checks. Thanks Tim Graham for the review. --- django/db/backends/base/base.py | 34 ++++++++++++++++++---- django/db/backends/base/features.py | 2 ++ django/db/backends/mysql/base.py | 4 +++ django/db/backends/mysql/features.py | 7 +++++ django/db/backends/oracle/base.py | 4 +++ django/db/backends/oracle/features.py | 1 + django/db/backends/postgresql/base.py | 8 ++++++ django/db/backends/postgresql/features.py | 1 + django/db/backends/sqlite3/base.py | 15 ++-------- django/db/backends/sqlite3/features.py | 1 + docs/releases/4.1.txt | 12 ++++++++ tests/backends/base/test_base.py | 35 +++++++++++++++++++++++ tests/backends/mysql/tests.py | 19 +++++++++++- tests/backends/oracle/tests.py | 18 ++++++++++-- tests/backends/postgresql/tests.py | 20 ++++++++++++- tests/backends/sqlite/tests.py | 26 +++++------------ 16 files changed, 166 insertions(+), 41 deletions(-) diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index d2e79c1dd4..f093f2bd8b 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -13,7 +13,7 @@ except ImportError: from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.db import DEFAULT_DB_ALIAS, DatabaseError +from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError from django.db.backends import utils from django.db.backends.base.validation import BaseDatabaseValidation from django.db.backends.signals import connection_created @@ -24,6 +24,7 @@ from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property NO_DB_ALIAS = "__no_db__" +RAN_DB_VERSION_CHECK = set() # RemovedInDjango50Warning @@ -185,6 +186,29 @@ class BaseDatabaseWrapper: ) return list(self.queries_log) + def get_database_version(self): + """Return a tuple of the database's version.""" + raise NotImplementedError( + "subclasses of BaseDatabaseWrapper may require a get_database_version() " + "method." + ) + + def check_database_version_supported(self): + """ + Raise an error if the database version isn't supported by this + version of Django. + """ + if ( + self.features.minimum_database_version is not None + and self.get_database_version() < self.features.minimum_database_version + ): + db_version = ".".join(map(str, self.get_database_version())) + min_db_version = ".".join(map(str, self.features.minimum_database_version)) + raise NotSupportedError( + f"{self.display_name} {min_db_version} or later is required " + f"(found {db_version})." + ) + # ##### Backend-specific methods for creating connections and cursors ##### def get_connection_params(self): @@ -203,10 +227,10 @@ class BaseDatabaseWrapper: def init_connection_state(self): """Initialize the database connection settings.""" - raise NotImplementedError( - "subclasses of BaseDatabaseWrapper may require an init_connection_state() " - "method" - ) + global RAN_DB_VERSION_CHECK + if self.alias not in RAN_DB_VERSION_CHECK: + self.check_database_version_supported() + RAN_DB_VERSION_CHECK.add(self.alias) def create_cursor(self, name=None): """Create a cursor. Assume that a connection is established.""" diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 42399b769a..ccf9104c21 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -3,6 +3,8 @@ from django.utils.functional import cached_property class BaseDatabaseFeatures: + # An optional tuple indicating the minimum supported database version. + minimum_database_version = None gis_enabled = False # Oracle can't group by LOB (large object) data types. allows_group_by_lob = True diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 1c20554b4d..ca12917322 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -200,6 +200,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): ops_class = DatabaseOperations validation_class = DatabaseValidation + def get_database_version(self): + return self.mysql_version + def get_connection_params(self): kwargs = { "conv": django_conversions, @@ -251,6 +254,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): return connection def init_connection_state(self): + super().init_connection_state() assignments = [] if self.features.is_sql_auto_is_null_enabled: # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index e27e766e48..f5b9ef9b55 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -48,6 +48,13 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_order_by_nulls_modifier = False order_by_nulls_first = True + @cached_property + def minimum_database_version(self): + if self.connection.mysql_is_mariadb: + return (10, 2) + else: + return (5, 7) + @cached_property def test_collations(self): charset = "utf8" diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 7cbee768ea..2ccd3bc028 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -239,6 +239,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): ) self.features.can_return_columns_from_insert = use_returning_into + def get_database_version(self): + return self.oracle_version + def get_connection_params(self): conn_params = self.settings_dict["OPTIONS"].copy() if "use_returning_into" in conn_params: @@ -255,6 +258,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): ) def init_connection_state(self): + super().init_connection_state() cursor = self.create_cursor() # Set the territory first. The territory overrides NLS_DATE_FORMAT # and NLS_TIMESTAMP_FORMAT to the territory default. When all of diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 257c3e1b87..2580e5c6c1 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -4,6 +4,7 @@ from django.utils.functional import cached_property class DatabaseFeatures(BaseDatabaseFeatures): + minimum_database_version = (19,) # Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got # BLOB" when grouping by LOBs (#24096). allows_group_by_lob = False diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 92f393227e..630da22964 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -153,6 +153,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): # PostgreSQL backend-specific attributes. _named_cursor_idx = 0 + def get_database_version(self): + """ + Return a tuple of the database's version. + E.g. for pg_version 120004, return (12, 4). + """ + return divmod(self.pg_version, 10000) + def get_connection_params(self): settings_dict = self.settings_dict # None may be used to connect to the default 'postgres' db @@ -236,6 +243,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): return False def init_connection_state(self): + super().init_connection_state() self.connection.set_client_encoding("UTF8") timezone_changed = self.ensure_timezone() diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 7c5d09d193..5e6752b97a 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -6,6 +6,7 @@ from django.utils.functional import cached_property class DatabaseFeatures(BaseDatabaseFeatures): + minimum_database_version = (10,) allows_group_by_selected_pks = True can_return_columns_from_insert = True can_return_rows_from_bulk_insert = True diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 5bcd61eb96..8ca076a1d9 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -29,15 +29,6 @@ def decoder(conv_func): return lambda s: conv_func(s.decode()) -def check_sqlite_version(): - if Database.sqlite_version_info < (3, 9, 0): - raise ImproperlyConfigured( - "SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version - ) - - -check_sqlite_version() - Database.register_converter("bool", b"1".__eq__) Database.register_converter("time", decoder(parse_time)) Database.register_converter("datetime", decoder(parse_datetime)) @@ -168,6 +159,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): kwargs.update({"check_same_thread": False, "uri": True}) return kwargs + def get_database_version(self): + return self.Database.sqlite_version_info + @async_unsafe def get_new_connection(self, conn_params): conn = Database.connect(**conn_params) @@ -179,9 +173,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): conn.execute("PRAGMA legacy_alter_table = OFF") return conn - def init_connection_state(self): - pass - def create_cursor(self, name=None): return self.connection.cursor(factory=SQLiteCursorWrapper) diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 9161ae3133..2886ecc3be 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -9,6 +9,7 @@ from .base import Database class DatabaseFeatures(BaseDatabaseFeatures): + minimum_database_version = (3, 9) test_db_allows_multiple_connections = False supports_unspecified_pk = True supports_timezones = False diff --git a/docs/releases/4.1.txt b/docs/releases/4.1.txt index 88d283ff65..936350979f 100644 --- a/docs/releases/4.1.txt +++ b/docs/releases/4.1.txt @@ -157,6 +157,18 @@ CSRF * ... +Database backends +~~~~~~~~~~~~~~~~~ + +* Third-party database backends can now specify the minimum required version of + the database using the ``DatabaseFeatures.minimum_database_version`` + attribute which is a tuple (e.g. ``(10, 0)`` means "10.0"). If a minimum + version is specified, backends must also implement + ``DatabaseWrapper.get_database_version()``, which returns a tuple of the + current database version. The backend's + ``DatabaseWrapper.init_connection_state()`` method must call ``super()`` in + order for the check to run. + Decorators ~~~~~~~~~~ diff --git a/tests/backends/base/test_base.py b/tests/backends/base/test_base.py index 00ef766d5d..57d22ce269 100644 --- a/tests/backends/base/test_base.py +++ b/tests/backends/base/test_base.py @@ -41,6 +41,19 @@ class DatabaseWrapperTests(SimpleTestCase): self.assertEqual(BaseDatabaseWrapper.display_name, "unknown") self.assertNotEqual(connection.display_name, "unknown") + def test_get_database_version(self): + with patch.object(BaseDatabaseWrapper, "__init__", return_value=None): + msg = ( + "subclasses of BaseDatabaseWrapper may require a " + "get_database_version() method." + ) + with self.assertRaisesMessage(NotImplementedError, msg): + BaseDatabaseWrapper().get_database_version() + + def test_check_database_version_supported_with_none_as_database_version(self): + with patch.object(connection.features, "minimum_database_version", None): + connection.check_database_version_supported() + class ExecuteWrapperTests(TestCase): @staticmethod @@ -297,3 +310,25 @@ class ConnectionHealthChecksTests(SimpleTestCase): connection.commit() connection.set_autocommit(True) self.assertIs(new_connection, connection.connection) + + +class MultiDatabaseTests(TestCase): + databases = {"default", "other"} + + def test_multi_database_init_connection_state_called_once(self): + for db in self.databases: + with self.subTest(database=db): + with patch.object(connections[db], "commit", return_value=None): + with patch.object( + connections[db], + "check_database_version_supported", + ) as mocked_check_database_version_supported: + connections[db].init_connection_state() + after_first_calls = len( + mocked_check_database_version_supported.mock_calls + ) + connections[db].init_connection_state() + self.assertEqual( + len(mocked_check_database_version_supported.mock_calls), + after_first_calls, + ) diff --git a/tests/backends/mysql/tests.py b/tests/backends/mysql/tests.py index 139b363bf4..e84762584a 100644 --- a/tests/backends/mysql/tests.py +++ b/tests/backends/mysql/tests.py @@ -1,8 +1,9 @@ import unittest from contextlib import contextmanager +from unittest import mock from django.core.exceptions import ImproperlyConfigured -from django.db import connection +from django.db import NotSupportedError, connection from django.test import TestCase, override_settings @@ -99,3 +100,19 @@ class IsolationLevelTests(TestCase): ) with self.assertRaisesMessage(ImproperlyConfigured, msg): new_connection.cursor() + + +@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests") +class Tests(TestCase): + @mock.patch.object(connection, "get_database_version") + def test_check_database_version_supported(self, mocked_get_database_version): + if connection.mysql_is_mariadb: + mocked_get_database_version.return_value = (10, 1) + msg = "MariaDB 10.2 or later is required (found 10.1)." + else: + mocked_get_database_version.return_value = (5, 6) + msg = "MySQL 5.7 or later is required (found 5.6)." + + with self.assertRaisesMessage(NotSupportedError, msg): + connection.check_database_version_supported() + self.assertTrue(mocked_get_database_version.called) diff --git a/tests/backends/oracle/tests.py b/tests/backends/oracle/tests.py index 3f51d57de8..9a4e8ad435 100644 --- a/tests/backends/oracle/tests.py +++ b/tests/backends/oracle/tests.py @@ -1,14 +1,15 @@ import unittest +from unittest import mock -from django.db import DatabaseError, connection +from django.db import DatabaseError, NotSupportedError, connection from django.db.models import BooleanField -from django.test import TransactionTestCase +from django.test import TestCase, TransactionTestCase from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ @unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") -class Tests(unittest.TestCase): +class Tests(TestCase): def test_quote_name(self): """'%' chars are escaped for query execution.""" name = '"SOME%NAME"' @@ -56,6 +57,17 @@ class Tests(unittest.TestCase): field.set_attributes_from_name("is_nice") self.assertIn('"IS_NICE" IN (0,1)', field.db_check(connection)) + @mock.patch.object( + connection, + "get_database_version", + return_value=(18, 1), + ) + def test_check_database_version_supported(self, mocked_get_database_version): + msg = "Oracle 19 or later is required (found 18.1)." + with self.assertRaisesMessage(NotSupportedError, msg): + connection.check_database_version_supported() + self.assertTrue(mocked_get_database_version.called) + @unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") class TransactionalTests(TransactionTestCase): diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py index af08f6f286..ce8ed7b4d5 100644 --- a/tests/backends/postgresql/tests.py +++ b/tests/backends/postgresql/tests.py @@ -4,7 +4,13 @@ from io import StringIO from unittest import mock from django.core.exceptions import ImproperlyConfigured -from django.db import DEFAULT_DB_ALIAS, DatabaseError, connection, connections +from django.db import ( + DEFAULT_DB_ALIAS, + DatabaseError, + NotSupportedError, + connection, + connections, +) from django.db.backends.base.base import BaseDatabaseWrapper from django.test import TestCase, override_settings @@ -303,3 +309,15 @@ class Tests(TestCase): [q["sql"] for q in connection.queries], [copy_expert_sql, "COPY django_session TO STDOUT"], ) + + def test_get_database_version(self): + new_connection = connection.copy() + new_connection.pg_version = 110009 + self.assertEqual(new_connection.get_database_version(), (11, 9)) + + @mock.patch.object(connection, "get_database_version", return_value=(9, 6)) + def test_check_database_version_supported(self, mocked_get_database_version): + msg = "PostgreSQL 10 or later is required (found 9.6)." + with self.assertRaisesMessage(NotSupportedError, msg): + connection.check_database_version_supported() + self.assertTrue(mocked_get_database_version.called) diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py index e167e09dcf..97505eaa36 100644 --- a/tests/backends/sqlite/tests.py +++ b/tests/backends/sqlite/tests.py @@ -4,10 +4,8 @@ import tempfile import threading import unittest from pathlib import Path -from sqlite3 import dbapi2 from unittest import mock -from django.core.exceptions import ImproperlyConfigured from django.db import NotSupportedError, connection, transaction from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance from django.db.utils import ConnectionHandler @@ -21,28 +19,11 @@ from django.test.utils import isolate_apps from ..models import Author, Item, Object, Square -try: - from django.db.backends.sqlite3.base import check_sqlite_version -except ImproperlyConfigured: - # Ignore "SQLite is too old" when running tests on another database. - pass - @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class Tests(TestCase): longMessage = True - def test_check_sqlite_version(self): - msg = "SQLite 3.9.0 or later is required (found 3.8.11.1)." - with mock.patch.object( - dbapi2, "sqlite_version_info", (3, 8, 11, 1) - ), mock.patch.object( - dbapi2, "sqlite_version", "3.8.11.1" - ), self.assertRaisesMessage( - ImproperlyConfigured, msg - ): - check_sqlite_version() - def test_aggregation(self): """Raise NotSupportedError when aggregating on date/time fields.""" for aggregate in (Sum, Avg, Variance, StdDev): @@ -125,6 +106,13 @@ class Tests(TestCase): connections["default"].close() self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db"))) + @mock.patch.object(connection, "get_database_version", return_value=(3, 8)) + def test_check_database_version_supported(self, mocked_get_database_version): + msg = "SQLite 3.9 or later is required (found 3.8)." + with self.assertRaisesMessage(NotSupportedError, msg): + connection.check_database_version_supported() + self.assertTrue(mocked_get_database_version.called) + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") @isolate_apps("backends")