mirror of https://github.com/django/django.git
Fixed #33379 -- Added minimum database version checks.
Thanks Tim Graham for the review.
This commit is contained in:
parent
737542390a
commit
9ac3ef59f9
|
@ -13,7 +13,7 @@ except ImportError:
|
|||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError
|
||||
from django.db.backends import utils
|
||||
from django.db.backends.base.validation import BaseDatabaseValidation
|
||||
from django.db.backends.signals import connection_created
|
||||
|
@ -24,6 +24,7 @@ from django.utils.asyncio import async_unsafe
|
|||
from django.utils.functional import cached_property
|
||||
|
||||
NO_DB_ALIAS = "__no_db__"
|
||||
RAN_DB_VERSION_CHECK = set()
|
||||
|
||||
|
||||
# RemovedInDjango50Warning
|
||||
|
@ -185,6 +186,29 @@ class BaseDatabaseWrapper:
|
|||
)
|
||||
return list(self.queries_log)
|
||||
|
||||
def get_database_version(self):
|
||||
"""Return a tuple of the database's version."""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseWrapper may require a get_database_version() "
|
||||
"method."
|
||||
)
|
||||
|
||||
def check_database_version_supported(self):
|
||||
"""
|
||||
Raise an error if the database version isn't supported by this
|
||||
version of Django.
|
||||
"""
|
||||
if (
|
||||
self.features.minimum_database_version is not None
|
||||
and self.get_database_version() < self.features.minimum_database_version
|
||||
):
|
||||
db_version = ".".join(map(str, self.get_database_version()))
|
||||
min_db_version = ".".join(map(str, self.features.minimum_database_version))
|
||||
raise NotSupportedError(
|
||||
f"{self.display_name} {min_db_version} or later is required "
|
||||
f"(found {db_version})."
|
||||
)
|
||||
|
||||
# ##### Backend-specific methods for creating connections and cursors #####
|
||||
|
||||
def get_connection_params(self):
|
||||
|
@ -203,10 +227,10 @@ class BaseDatabaseWrapper:
|
|||
|
||||
def init_connection_state(self):
|
||||
"""Initialize the database connection settings."""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseWrapper may require an init_connection_state() "
|
||||
"method"
|
||||
)
|
||||
global RAN_DB_VERSION_CHECK
|
||||
if self.alias not in RAN_DB_VERSION_CHECK:
|
||||
self.check_database_version_supported()
|
||||
RAN_DB_VERSION_CHECK.add(self.alias)
|
||||
|
||||
def create_cursor(self, name=None):
|
||||
"""Create a cursor. Assume that a connection is established."""
|
||||
|
|
|
@ -3,6 +3,8 @@ from django.utils.functional import cached_property
|
|||
|
||||
|
||||
class BaseDatabaseFeatures:
|
||||
# An optional tuple indicating the minimum supported database version.
|
||||
minimum_database_version = None
|
||||
gis_enabled = False
|
||||
# Oracle can't group by LOB (large object) data types.
|
||||
allows_group_by_lob = True
|
||||
|
|
|
@ -200,6 +200,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
ops_class = DatabaseOperations
|
||||
validation_class = DatabaseValidation
|
||||
|
||||
def get_database_version(self):
|
||||
return self.mysql_version
|
||||
|
||||
def get_connection_params(self):
|
||||
kwargs = {
|
||||
"conv": django_conversions,
|
||||
|
@ -251,6 +254,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
return connection
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
assignments = []
|
||||
if self.features.is_sql_auto_is_null_enabled:
|
||||
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
|
||||
|
|
|
@ -48,6 +48,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
supports_order_by_nulls_modifier = False
|
||||
order_by_nulls_first = True
|
||||
|
||||
@cached_property
|
||||
def minimum_database_version(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return (10, 2)
|
||||
else:
|
||||
return (5, 7)
|
||||
|
||||
@cached_property
|
||||
def test_collations(self):
|
||||
charset = "utf8"
|
||||
|
|
|
@ -239,6 +239,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
)
|
||||
self.features.can_return_columns_from_insert = use_returning_into
|
||||
|
||||
def get_database_version(self):
|
||||
return self.oracle_version
|
||||
|
||||
def get_connection_params(self):
|
||||
conn_params = self.settings_dict["OPTIONS"].copy()
|
||||
if "use_returning_into" in conn_params:
|
||||
|
@ -255,6 +258,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
)
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
cursor = self.create_cursor()
|
||||
# Set the territory first. The territory overrides NLS_DATE_FORMAT
|
||||
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
|
||||
|
|
|
@ -4,6 +4,7 @@ from django.utils.functional import cached_property
|
|||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
minimum_database_version = (19,)
|
||||
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
|
||||
# BLOB" when grouping by LOBs (#24096).
|
||||
allows_group_by_lob = False
|
||||
|
|
|
@ -153,6 +153,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
# PostgreSQL backend-specific attributes.
|
||||
_named_cursor_idx = 0
|
||||
|
||||
def get_database_version(self):
|
||||
"""
|
||||
Return a tuple of the database's version.
|
||||
E.g. for pg_version 120004, return (12, 4).
|
||||
"""
|
||||
return divmod(self.pg_version, 10000)
|
||||
|
||||
def get_connection_params(self):
|
||||
settings_dict = self.settings_dict
|
||||
# None may be used to connect to the default 'postgres' db
|
||||
|
@ -236,6 +243,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
return False
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
self.connection.set_client_encoding("UTF8")
|
||||
|
||||
timezone_changed = self.ensure_timezone()
|
||||
|
|
|
@ -6,6 +6,7 @@ from django.utils.functional import cached_property
|
|||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
minimum_database_version = (10,)
|
||||
allows_group_by_selected_pks = True
|
||||
can_return_columns_from_insert = True
|
||||
can_return_rows_from_bulk_insert = True
|
||||
|
|
|
@ -29,15 +29,6 @@ def decoder(conv_func):
|
|||
return lambda s: conv_func(s.decode())
|
||||
|
||||
|
||||
def check_sqlite_version():
|
||||
if Database.sqlite_version_info < (3, 9, 0):
|
||||
raise ImproperlyConfigured(
|
||||
"SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
|
||||
)
|
||||
|
||||
|
||||
check_sqlite_version()
|
||||
|
||||
Database.register_converter("bool", b"1".__eq__)
|
||||
Database.register_converter("time", decoder(parse_time))
|
||||
Database.register_converter("datetime", decoder(parse_datetime))
|
||||
|
@ -168,6 +159,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
kwargs.update({"check_same_thread": False, "uri": True})
|
||||
return kwargs
|
||||
|
||||
def get_database_version(self):
|
||||
return self.Database.sqlite_version_info
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
conn = Database.connect(**conn_params)
|
||||
|
@ -179,9 +173,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
conn.execute("PRAGMA legacy_alter_table = OFF")
|
||||
return conn
|
||||
|
||||
def init_connection_state(self):
|
||||
pass
|
||||
|
||||
def create_cursor(self, name=None):
|
||||
return self.connection.cursor(factory=SQLiteCursorWrapper)
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from .base import Database
|
|||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
minimum_database_version = (3, 9)
|
||||
test_db_allows_multiple_connections = False
|
||||
supports_unspecified_pk = True
|
||||
supports_timezones = False
|
||||
|
|
|
@ -157,6 +157,18 @@ CSRF
|
|||
|
||||
* ...
|
||||
|
||||
Database backends
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
* Third-party database backends can now specify the minimum required version of
|
||||
the database using the ``DatabaseFeatures.minimum_database_version``
|
||||
attribute which is a tuple (e.g. ``(10, 0)`` means "10.0"). If a minimum
|
||||
version is specified, backends must also implement
|
||||
``DatabaseWrapper.get_database_version()``, which returns a tuple of the
|
||||
current database version. The backend's
|
||||
``DatabaseWrapper.init_connection_state()`` method must call ``super()`` in
|
||||
order for the check to run.
|
||||
|
||||
Decorators
|
||||
~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -41,6 +41,19 @@ class DatabaseWrapperTests(SimpleTestCase):
|
|||
self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
|
||||
self.assertNotEqual(connection.display_name, "unknown")
|
||||
|
||||
def test_get_database_version(self):
|
||||
with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
|
||||
msg = (
|
||||
"subclasses of BaseDatabaseWrapper may require a "
|
||||
"get_database_version() method."
|
||||
)
|
||||
with self.assertRaisesMessage(NotImplementedError, msg):
|
||||
BaseDatabaseWrapper().get_database_version()
|
||||
|
||||
def test_check_database_version_supported_with_none_as_database_version(self):
|
||||
with patch.object(connection.features, "minimum_database_version", None):
|
||||
connection.check_database_version_supported()
|
||||
|
||||
|
||||
class ExecuteWrapperTests(TestCase):
|
||||
@staticmethod
|
||||
|
@ -297,3 +310,25 @@ class ConnectionHealthChecksTests(SimpleTestCase):
|
|||
connection.commit()
|
||||
connection.set_autocommit(True)
|
||||
self.assertIs(new_connection, connection.connection)
|
||||
|
||||
|
||||
class MultiDatabaseTests(TestCase):
|
||||
databases = {"default", "other"}
|
||||
|
||||
def test_multi_database_init_connection_state_called_once(self):
|
||||
for db in self.databases:
|
||||
with self.subTest(database=db):
|
||||
with patch.object(connections[db], "commit", return_value=None):
|
||||
with patch.object(
|
||||
connections[db],
|
||||
"check_database_version_supported",
|
||||
) as mocked_check_database_version_supported:
|
||||
connections[db].init_connection_state()
|
||||
after_first_calls = len(
|
||||
mocked_check_database_version_supported.mock_calls
|
||||
)
|
||||
connections[db].init_connection_state()
|
||||
self.assertEqual(
|
||||
len(mocked_check_database_version_supported.mock_calls),
|
||||
after_first_calls,
|
||||
)
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from unittest import mock
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import connection
|
||||
from django.db import NotSupportedError, connection
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
|
||||
|
@ -99,3 +100,19 @@ class IsolationLevelTests(TestCase):
|
|||
)
|
||||
with self.assertRaisesMessage(ImproperlyConfigured, msg):
|
||||
new_connection.cursor()
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests")
|
||||
class Tests(TestCase):
|
||||
@mock.patch.object(connection, "get_database_version")
|
||||
def test_check_database_version_supported(self, mocked_get_database_version):
|
||||
if connection.mysql_is_mariadb:
|
||||
mocked_get_database_version.return_value = (10, 1)
|
||||
msg = "MariaDB 10.2 or later is required (found 10.1)."
|
||||
else:
|
||||
mocked_get_database_version.return_value = (5, 6)
|
||||
msg = "MySQL 5.7 or later is required (found 5.6)."
|
||||
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
connection.check_database_version_supported()
|
||||
self.assertTrue(mocked_get_database_version.called)
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from django.db import DatabaseError, connection
|
||||
from django.db import DatabaseError, NotSupportedError, connection
|
||||
from django.db.models import BooleanField
|
||||
from django.test import TransactionTestCase
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
|
||||
from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
|
||||
class Tests(unittest.TestCase):
|
||||
class Tests(TestCase):
|
||||
def test_quote_name(self):
|
||||
"""'%' chars are escaped for query execution."""
|
||||
name = '"SOME%NAME"'
|
||||
|
@ -56,6 +57,17 @@ class Tests(unittest.TestCase):
|
|||
field.set_attributes_from_name("is_nice")
|
||||
self.assertIn('"IS_NICE" IN (0,1)', field.db_check(connection))
|
||||
|
||||
@mock.patch.object(
|
||||
connection,
|
||||
"get_database_version",
|
||||
return_value=(18, 1),
|
||||
)
|
||||
def test_check_database_version_supported(self, mocked_get_database_version):
|
||||
msg = "Oracle 19 or later is required (found 18.1)."
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
connection.check_database_version_supported()
|
||||
self.assertTrue(mocked_get_database_version.called)
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
|
||||
class TransactionalTests(TransactionTestCase):
|
||||
|
|
|
@ -4,7 +4,13 @@ from io import StringIO
|
|||
from unittest import mock
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connection, connections
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
DatabaseError,
|
||||
NotSupportedError,
|
||||
connection,
|
||||
connections,
|
||||
)
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
|
@ -303,3 +309,15 @@ class Tests(TestCase):
|
|||
[q["sql"] for q in connection.queries],
|
||||
[copy_expert_sql, "COPY django_session TO STDOUT"],
|
||||
)
|
||||
|
||||
def test_get_database_version(self):
|
||||
new_connection = connection.copy()
|
||||
new_connection.pg_version = 110009
|
||||
self.assertEqual(new_connection.get_database_version(), (11, 9))
|
||||
|
||||
@mock.patch.object(connection, "get_database_version", return_value=(9, 6))
|
||||
def test_check_database_version_supported(self, mocked_get_database_version):
|
||||
msg = "PostgreSQL 10 or later is required (found 9.6)."
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
connection.check_database_version_supported()
|
||||
self.assertTrue(mocked_get_database_version.called)
|
||||
|
|
|
@ -4,10 +4,8 @@ import tempfile
|
|||
import threading
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from sqlite3 import dbapi2
|
||||
from unittest import mock
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import NotSupportedError, connection, transaction
|
||||
from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance
|
||||
from django.db.utils import ConnectionHandler
|
||||
|
@ -21,28 +19,11 @@ from django.test.utils import isolate_apps
|
|||
|
||||
from ..models import Author, Item, Object, Square
|
||||
|
||||
try:
|
||||
from django.db.backends.sqlite3.base import check_sqlite_version
|
||||
except ImproperlyConfigured:
|
||||
# Ignore "SQLite is too old" when running tests on another database.
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
|
||||
class Tests(TestCase):
|
||||
longMessage = True
|
||||
|
||||
def test_check_sqlite_version(self):
|
||||
msg = "SQLite 3.9.0 or later is required (found 3.8.11.1)."
|
||||
with mock.patch.object(
|
||||
dbapi2, "sqlite_version_info", (3, 8, 11, 1)
|
||||
), mock.patch.object(
|
||||
dbapi2, "sqlite_version", "3.8.11.1"
|
||||
), self.assertRaisesMessage(
|
||||
ImproperlyConfigured, msg
|
||||
):
|
||||
check_sqlite_version()
|
||||
|
||||
def test_aggregation(self):
|
||||
"""Raise NotSupportedError when aggregating on date/time fields."""
|
||||
for aggregate in (Sum, Avg, Variance, StdDev):
|
||||
|
@ -125,6 +106,13 @@ class Tests(TestCase):
|
|||
connections["default"].close()
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db")))
|
||||
|
||||
@mock.patch.object(connection, "get_database_version", return_value=(3, 8))
|
||||
def test_check_database_version_supported(self, mocked_get_database_version):
|
||||
msg = "SQLite 3.9 or later is required (found 3.8)."
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
connection.check_database_version_supported()
|
||||
self.assertTrue(mocked_get_database_version.called)
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
|
||||
@isolate_apps("backends")
|
||||
|
|
Loading…
Reference in New Issue