Fixed #33379 -- Added minimum database version checks.

Thanks Tim Graham for the review.
This commit is contained in:
Hasan Ramezani 2021-12-27 19:04:59 +01:00 committed by Mariusz Felisiak
parent 737542390a
commit 9ac3ef59f9
16 changed files with 166 additions and 41 deletions

View File

@ -13,7 +13,7 @@ except ImportError:
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError
from django.db.backends import utils from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created from django.db.backends.signals import connection_created
@ -24,6 +24,7 @@ from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property from django.utils.functional import cached_property
NO_DB_ALIAS = "__no_db__" NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()
# RemovedInDjango50Warning # RemovedInDjango50Warning
@ -185,6 +186,29 @@ class BaseDatabaseWrapper:
) )
return list(self.queries_log) return list(self.queries_log)
def get_database_version(self):
"""Return a tuple of the database's version."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a get_database_version() "
"method."
)
def check_database_version_supported(self):
"""
Raise an error if the database version isn't supported by this
version of Django.
"""
if (
self.features.minimum_database_version is not None
and self.get_database_version() < self.features.minimum_database_version
):
db_version = ".".join(map(str, self.get_database_version()))
min_db_version = ".".join(map(str, self.features.minimum_database_version))
raise NotSupportedError(
f"{self.display_name} {min_db_version} or later is required "
f"(found {db_version})."
)
# ##### Backend-specific methods for creating connections and cursors ##### # ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self): def get_connection_params(self):
@ -203,10 +227,10 @@ class BaseDatabaseWrapper:
def init_connection_state(self): def init_connection_state(self):
"""Initialize the database connection settings.""" """Initialize the database connection settings."""
raise NotImplementedError( global RAN_DB_VERSION_CHECK
"subclasses of BaseDatabaseWrapper may require an init_connection_state() " if self.alias not in RAN_DB_VERSION_CHECK:
"method" self.check_database_version_supported()
) RAN_DB_VERSION_CHECK.add(self.alias)
def create_cursor(self, name=None): def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established.""" """Create a cursor. Assume that a connection is established."""

View File

@ -3,6 +3,8 @@ from django.utils.functional import cached_property
class BaseDatabaseFeatures: class BaseDatabaseFeatures:
# An optional tuple indicating the minimum supported database version.
minimum_database_version = None
gis_enabled = False gis_enabled = False
# Oracle can't group by LOB (large object) data types. # Oracle can't group by LOB (large object) data types.
allows_group_by_lob = True allows_group_by_lob = True

View File

@ -200,6 +200,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
ops_class = DatabaseOperations ops_class = DatabaseOperations
validation_class = DatabaseValidation validation_class = DatabaseValidation
def get_database_version(self):
return self.mysql_version
def get_connection_params(self): def get_connection_params(self):
kwargs = { kwargs = {
"conv": django_conversions, "conv": django_conversions,
@ -251,6 +254,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return connection return connection
def init_connection_state(self): def init_connection_state(self):
super().init_connection_state()
assignments = [] assignments = []
if self.features.is_sql_auto_is_null_enabled: if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on

View File

@ -48,6 +48,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_order_by_nulls_modifier = False supports_order_by_nulls_modifier = False
order_by_nulls_first = True order_by_nulls_first = True
@cached_property
def minimum_database_version(self):
if self.connection.mysql_is_mariadb:
return (10, 2)
else:
return (5, 7)
@cached_property @cached_property
def test_collations(self): def test_collations(self):
charset = "utf8" charset = "utf8"

View File

@ -239,6 +239,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
) )
self.features.can_return_columns_from_insert = use_returning_into self.features.can_return_columns_from_insert = use_returning_into
def get_database_version(self):
return self.oracle_version
def get_connection_params(self): def get_connection_params(self):
conn_params = self.settings_dict["OPTIONS"].copy() conn_params = self.settings_dict["OPTIONS"].copy()
if "use_returning_into" in conn_params: if "use_returning_into" in conn_params:
@ -255,6 +258,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
) )
def init_connection_state(self): def init_connection_state(self):
super().init_connection_state()
cursor = self.create_cursor() cursor = self.create_cursor()
# Set the territory first. The territory overrides NLS_DATE_FORMAT # Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of # and NLS_TIMESTAMP_FORMAT to the territory default. When all of

View File

@ -4,6 +4,7 @@ from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (19,)
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got # Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
# BLOB" when grouping by LOBs (#24096). # BLOB" when grouping by LOBs (#24096).
allows_group_by_lob = False allows_group_by_lob = False

View File

@ -153,6 +153,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# PostgreSQL backend-specific attributes. # PostgreSQL backend-specific attributes.
_named_cursor_idx = 0 _named_cursor_idx = 0
def get_database_version(self):
"""
Return a tuple of the database's version.
E.g. for pg_version 120004, return (12, 4).
"""
return divmod(self.pg_version, 10000)
def get_connection_params(self): def get_connection_params(self):
settings_dict = self.settings_dict settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db # None may be used to connect to the default 'postgres' db
@ -236,6 +243,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False return False
def init_connection_state(self): def init_connection_state(self):
super().init_connection_state()
self.connection.set_client_encoding("UTF8") self.connection.set_client_encoding("UTF8")
timezone_changed = self.ensure_timezone() timezone_changed = self.ensure_timezone()

View File

@ -6,6 +6,7 @@ from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (10,)
allows_group_by_selected_pks = True allows_group_by_selected_pks = True
can_return_columns_from_insert = True can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True can_return_rows_from_bulk_insert = True

View File

@ -29,15 +29,6 @@ def decoder(conv_func):
return lambda s: conv_func(s.decode()) return lambda s: conv_func(s.decode())
def check_sqlite_version():
if Database.sqlite_version_info < (3, 9, 0):
raise ImproperlyConfigured(
"SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
)
check_sqlite_version()
Database.register_converter("bool", b"1".__eq__) Database.register_converter("bool", b"1".__eq__)
Database.register_converter("time", decoder(parse_time)) Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime)) Database.register_converter("datetime", decoder(parse_datetime))
@ -168,6 +159,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
kwargs.update({"check_same_thread": False, "uri": True}) kwargs.update({"check_same_thread": False, "uri": True})
return kwargs return kwargs
def get_database_version(self):
return self.Database.sqlite_version_info
@async_unsafe @async_unsafe
def get_new_connection(self, conn_params): def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params) conn = Database.connect(**conn_params)
@ -179,9 +173,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn.execute("PRAGMA legacy_alter_table = OFF") conn.execute("PRAGMA legacy_alter_table = OFF")
return conn return conn
def init_connection_state(self):
pass
def create_cursor(self, name=None): def create_cursor(self, name=None):
return self.connection.cursor(factory=SQLiteCursorWrapper) return self.connection.cursor(factory=SQLiteCursorWrapper)

View File

@ -9,6 +9,7 @@ from .base import Database
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (3, 9)
test_db_allows_multiple_connections = False test_db_allows_multiple_connections = False
supports_unspecified_pk = True supports_unspecified_pk = True
supports_timezones = False supports_timezones = False

View File

@ -157,6 +157,18 @@ CSRF
* ... * ...
Database backends
~~~~~~~~~~~~~~~~~
* Third-party database backends can now specify the minimum required version of
the database using the ``DatabaseFeatures.minimum_database_version``
attribute which is a tuple (e.g. ``(10, 0)`` means "10.0"). If a minimum
version is specified, backends must also implement
``DatabaseWrapper.get_database_version()``, which returns a tuple of the
current database version. The backend's
``DatabaseWrapper.init_connection_state()`` method must call ``super()`` in
order for the check to run.
Decorators Decorators
~~~~~~~~~~ ~~~~~~~~~~

View File

@ -41,6 +41,19 @@ class DatabaseWrapperTests(SimpleTestCase):
self.assertEqual(BaseDatabaseWrapper.display_name, "unknown") self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
self.assertNotEqual(connection.display_name, "unknown") self.assertNotEqual(connection.display_name, "unknown")
def test_get_database_version(self):
with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
msg = (
"subclasses of BaseDatabaseWrapper may require a "
"get_database_version() method."
)
with self.assertRaisesMessage(NotImplementedError, msg):
BaseDatabaseWrapper().get_database_version()
def test_check_database_version_supported_with_none_as_database_version(self):
with patch.object(connection.features, "minimum_database_version", None):
connection.check_database_version_supported()
class ExecuteWrapperTests(TestCase): class ExecuteWrapperTests(TestCase):
@staticmethod @staticmethod
@ -297,3 +310,25 @@ class ConnectionHealthChecksTests(SimpleTestCase):
connection.commit() connection.commit()
connection.set_autocommit(True) connection.set_autocommit(True)
self.assertIs(new_connection, connection.connection) self.assertIs(new_connection, connection.connection)
class MultiDatabaseTests(TestCase):
databases = {"default", "other"}
def test_multi_database_init_connection_state_called_once(self):
for db in self.databases:
with self.subTest(database=db):
with patch.object(connections[db], "commit", return_value=None):
with patch.object(
connections[db],
"check_database_version_supported",
) as mocked_check_database_version_supported:
connections[db].init_connection_state()
after_first_calls = len(
mocked_check_database_version_supported.mock_calls
)
connections[db].init_connection_state()
self.assertEqual(
len(mocked_check_database_version_supported.mock_calls),
after_first_calls,
)

View File

@ -1,8 +1,9 @@
import unittest import unittest
from contextlib import contextmanager from contextlib import contextmanager
from unittest import mock
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import connection from django.db import NotSupportedError, connection
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
@ -99,3 +100,19 @@ class IsolationLevelTests(TestCase):
) )
with self.assertRaisesMessage(ImproperlyConfigured, msg): with self.assertRaisesMessage(ImproperlyConfigured, msg):
new_connection.cursor() new_connection.cursor()
@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests")
class Tests(TestCase):
@mock.patch.object(connection, "get_database_version")
def test_check_database_version_supported(self, mocked_get_database_version):
if connection.mysql_is_mariadb:
mocked_get_database_version.return_value = (10, 1)
msg = "MariaDB 10.2 or later is required (found 10.1)."
else:
mocked_get_database_version.return_value = (5, 6)
msg = "MySQL 5.7 or later is required (found 5.6)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)

View File

@ -1,14 +1,15 @@
import unittest import unittest
from unittest import mock
from django.db import DatabaseError, connection from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import BooleanField from django.db.models import BooleanField
from django.test import TransactionTestCase from django.test import TestCase, TransactionTestCase
from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") @unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
class Tests(unittest.TestCase): class Tests(TestCase):
def test_quote_name(self): def test_quote_name(self):
"""'%' chars are escaped for query execution.""" """'%' chars are escaped for query execution."""
name = '"SOME%NAME"' name = '"SOME%NAME"'
@ -56,6 +57,17 @@ class Tests(unittest.TestCase):
field.set_attributes_from_name("is_nice") field.set_attributes_from_name("is_nice")
self.assertIn('"IS_NICE" IN (0,1)', field.db_check(connection)) self.assertIn('"IS_NICE" IN (0,1)', field.db_check(connection))
@mock.patch.object(
connection,
"get_database_version",
return_value=(18, 1),
)
def test_check_database_version_supported(self, mocked_get_database_version):
msg = "Oracle 19 or later is required (found 18.1)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)
@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") @unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
class TransactionalTests(TransactionTestCase): class TransactionalTests(TransactionTestCase):

View File

@ -4,7 +4,13 @@ from io import StringIO
from unittest import mock from unittest import mock
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connection, connections from django.db import (
DEFAULT_DB_ALIAS,
DatabaseError,
NotSupportedError,
connection,
connections,
)
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
@ -303,3 +309,15 @@ class Tests(TestCase):
[q["sql"] for q in connection.queries], [q["sql"] for q in connection.queries],
[copy_expert_sql, "COPY django_session TO STDOUT"], [copy_expert_sql, "COPY django_session TO STDOUT"],
) )
def test_get_database_version(self):
new_connection = connection.copy()
new_connection.pg_version = 110009
self.assertEqual(new_connection.get_database_version(), (11, 9))
@mock.patch.object(connection, "get_database_version", return_value=(9, 6))
def test_check_database_version_supported(self, mocked_get_database_version):
msg = "PostgreSQL 10 or later is required (found 9.6)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)

View File

@ -4,10 +4,8 @@ import tempfile
import threading import threading
import unittest import unittest
from pathlib import Path from pathlib import Path
from sqlite3 import dbapi2
from unittest import mock from unittest import mock
from django.core.exceptions import ImproperlyConfigured
from django.db import NotSupportedError, connection, transaction from django.db import NotSupportedError, connection, transaction
from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance
from django.db.utils import ConnectionHandler from django.db.utils import ConnectionHandler
@ -21,28 +19,11 @@ from django.test.utils import isolate_apps
from ..models import Author, Item, Object, Square from ..models import Author, Item, Object, Square
try:
from django.db.backends.sqlite3.base import check_sqlite_version
except ImproperlyConfigured:
# Ignore "SQLite is too old" when running tests on another database.
pass
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
class Tests(TestCase): class Tests(TestCase):
longMessage = True longMessage = True
def test_check_sqlite_version(self):
msg = "SQLite 3.9.0 or later is required (found 3.8.11.1)."
with mock.patch.object(
dbapi2, "sqlite_version_info", (3, 8, 11, 1)
), mock.patch.object(
dbapi2, "sqlite_version", "3.8.11.1"
), self.assertRaisesMessage(
ImproperlyConfigured, msg
):
check_sqlite_version()
def test_aggregation(self): def test_aggregation(self):
"""Raise NotSupportedError when aggregating on date/time fields.""" """Raise NotSupportedError when aggregating on date/time fields."""
for aggregate in (Sum, Avg, Variance, StdDev): for aggregate in (Sum, Avg, Variance, StdDev):
@ -125,6 +106,13 @@ class Tests(TestCase):
connections["default"].close() connections["default"].close()
self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db"))) self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db")))
@mock.patch.object(connection, "get_database_version", return_value=(3, 8))
def test_check_database_version_supported(self, mocked_get_database_version):
msg = "SQLite 3.9 or later is required (found 3.8)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
@isolate_apps("backends") @isolate_apps("backends")