diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index bf6ba9355..d1e13778b 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -35,6 +35,14 @@ def handle_island_errors(fn): def decorated(*args, **kwargs): try: return fn(*args, **kwargs) + except ( + IslandAPIConnectionError, + IslandAPIError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, + ) as e: + raise e except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err: raise IslandAPIConnectionError(err) except requests.exceptions.HTTPError as err: diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index d4acc1c79..033936ba2 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -39,6 +39,7 @@ ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}" ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents" +ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}" class Event1(AbstractAgentEvent): @@ -318,3 +319,47 @@ def test_island_api_client_register_agent__status_code(status_code, expected_err with pytest.raises(expected_error): m.post(ISLAND_REGISTER_AGENT_URI, status_code=status_code) island_api_client.register_agent(AGENT_REGISTRATION) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + ], +) +def test_island_api_client__should_agent_stop(actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_AGENT_STOP_URI, exc=actual_error) + island_api_client.should_agent_stop(AGENT_ID) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_should_agent_stop__status_code(status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_AGENT_STOP_URI, status_code=status_code) + island_api_client.should_agent_stop(AGENT_ID) + + +def test_island_api_client_should_agent_stop__bad_json(): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(IslandAPIRequestFailedError): + m.get(ISLAND_AGENT_STOP_URI, content=b"bad") + island_api_client.should_agent_stop(AGENT_ID)