diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py index 7d48e3b2f..ad099cd9c 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py @@ -12,20 +12,35 @@ from monkey_island.cc.resources.island_mode import IslandMode as IslandModeResou @pytest.fixture -def flask_client(build_flask_client): - container = StubDIContainer() +def flask_client_builder(build_flask_client): + def inner(side_effect=None): + container = StubDIContainer() - in_memory_simulation_repository = InMemorySimulationRepository() - container.register_instance(ISimulationRepository, in_memory_simulation_repository) + in_memory_simulation_repository = InMemorySimulationRepository() + container.register_instance(ISimulationRepository, in_memory_simulation_repository) - mock_island_event_queue = MagicMock(spec=IIslandEventQueue) - mock_island_event_queue.publish.side_effect = ( - lambda topic, mode: in_memory_simulation_repository.set_mode(mode) - ) - container.register_instance(IIslandEventQueue, mock_island_event_queue) + mock_island_event_queue = MagicMock(spec=IIslandEventQueue) + mock_island_event_queue.publish.side_effect = ( + side_effect + if side_effect + else lambda topic, mode: in_memory_simulation_repository.set_mode(mode) + ) + container.register_instance(IIslandEventQueue, mock_island_event_queue) - with build_flask_client(container) as flask_client: - yield flask_client + with build_flask_client(container) as flask_client: + return flask_client + + return inner + + +@pytest.fixture +def flask_client(flask_client_builder): + return flask_client_builder() + + +@pytest.fixture +def flask_client__internal_server_error(flask_client_builder): + return flask_client_builder(Exception) @pytest.mark.parametrize( @@ -50,22 +65,12 @@ def test_island_mode_post__invalid_mode(flask_client): assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY -def test_island_mode_post__internal_server_error(build_flask_client): - container = StubDIContainer() - in_memory_simulation_repository = InMemorySimulationRepository() - container.register_instance(ISimulationRepository, in_memory_simulation_repository) - - mock_island_event_queue = MagicMock(spec=IIslandEventQueue) - mock_island_event_queue.publish.side_effect = Exception - container.register_instance(IIslandEventQueue, mock_island_event_queue) - - with build_flask_client(container) as flask_client: - resp = flask_client.put( - IslandModeResource.urls[0], - json=IslandMode.RANSOMWARE.value, - follow_redirects=True, - ) - +def test_island_mode_post__internal_server_error(flask_client__internal_server_error): + resp = flask_client__internal_server_error.put( + IslandModeResource.urls[0], + json=IslandMode.RANSOMWARE.value, + follow_redirects=True, + ) assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR