diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index bf4d70056..cae3cacbd 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,5 +1,6 @@ import pytest import requests +import requests_mock from infection_monkey.network.relay.utils import find_server @@ -9,27 +10,18 @@ SERVER_3 = "3.3.3.3:3142" SERVER_4 = "4.4.4.4:5000" -class MockConnectionError: - def __init__(self, *args, **kwargs): - raise requests.exceptions.ConnectionError - - -class MockRequestsGetResponsePerServerArgument: - def __init__(self, *args, **kwargs): - if SERVER_1 in args[0]: - MockConnectionError() - - -@pytest.fixture -def servers(): - return [SERVER_1, SERVER_2, SERVER_3, SERVER_4] +servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] @pytest.mark.parametrize( - "mock_requests_get, expected", - [(MockConnectionError, None), (MockRequestsGetResponsePerServerArgument, SERVER_2)], + "expected_server,connection_error_servers,do_nothing_servers", + [(None, servers, []), (SERVER_2, [SERVER_1], [SERVER_2, SERVER_3, SERVER_4])], ) -def test_find_server(monkeypatch, servers, mock_requests_get, expected): - monkeypatch.setattr("infection_monkey.control.requests.get", mock_requests_get) +def test_find_server(expected_server, connection_error_servers, do_nothing_servers): + with requests_mock.Mocker() as mock: + for server in connection_error_servers: + mock.get(f"https://{server}/api?action=is-up", exc=requests.exceptions.ConnectionError) + for server in do_nothing_servers: + mock.get(f"https://{server}/api?action=is-up", text="") - assert find_server(servers) is expected + assert find_server(servers) is expected_server