diff --git a/tests/queries/test_explain.py b/tests/queries/test_explain.py index f8ec9f445d..1d17ef8191 100644 --- a/tests/queries/test_explain.py +++ b/tests/queries/test_explain.py @@ -20,8 +20,9 @@ class ExplainTests(TestCase): Tag.objects.filter(name="test").annotate(Count("children")), Tag.objects.filter(name="test").values_list("name"), Tag.objects.order_by().union(Tag.objects.order_by().filter(name="test")), - Tag.objects.select_for_update().filter(name="test"), ] + if connection.features.has_select_for_update: + querysets.append(Tag.objects.select_for_update().filter(name="test")) supported_formats = connection.features.supported_explain_formats all_formats = ( (None,) @@ -31,13 +32,19 @@ class ExplainTests(TestCase): for idx, queryset in enumerate(querysets): for format in all_formats: with self.subTest(format=format, queryset=idx): - with self.assertNumQueries(1), CaptureQueriesContext( - connection - ) as captured_queries: - result = queryset.explain(format=format) + with CaptureQueriesContext(connection) as captured_queries: + if queryset.query.select_for_update: + with transaction.atomic(): + result = queryset.explain(format=format) + else: + result = queryset.explain(format=format) + self.assertEqual(len(captured_queries), 1) self.assertTrue( - captured_queries[0]["sql"].startswith( - connection.ops.explain_prefix + any( + captured_query["sql"].startswith( + connection.ops.explain_prefix + ) + for captured_query in captured_queries ) ) self.assertIsInstance(result, str)