diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index e5e338c4e..17e2f9284 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -35,7 +35,7 @@ def handle_island_errors(fn): def decorated(*args, **kwargs): try: return fn(*args, **kwargs) - except requests.exceptions.ConnectionError as err: + except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err: raise IslandAPIConnectionError(err) except requests.exceptions.HTTPError as err: if 400 <= err.response.status_code < 500: @@ -54,6 +54,17 @@ def handle_island_errors(fn): return decorated +def convert_json_error_to_island_api_error(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + fn(*args, **kwargs) + except json.JSONDecodeError as e: + raise IslandAPIRequestFailedError(e) + + return wrapper + + class HTTPIslandAPIClient(IIslandAPIClient): """ A client for the Island's HTTP API @@ -125,104 +136,56 @@ class HTTPIslandAPIClient(IIslandAPIClient): response.raise_for_status() def register_agent(self, agent_registration_data: AgentRegistrationData): - try: - url = f"https://{agent_registration_data.cc_server}/api/agents" - response = requests.post( # noqa: DUO123 - url, - json=agent_registration_data.dict(simplify=True), - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - except ( - requests.exceptions.ConnectionError, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, - ) as e: - raise IslandAPIConnectionError(e) - except requests.exceptions.Timeout as e: - raise IslandAPITimeoutError(e) + url = f"https://{agent_registration_data.cc_server}/api/agents" + response = requests.post( # noqa: DUO123 + url, + json=agent_registration_data.dict(simplify=True), + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + @handle_island_errors + @convert_json_error_to_island_api_error def should_agent_stop(self, island_server: str, agent_id: str) -> bool: - try: - url = f"https://{island_server}/api/monkey-control" f"/needs-to-stop/{agent_id}" - response = requests.get( # noqa: DUO123 - url, - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() + url = f"https://{island_server}/api/monkey-control" f"/needs-to-stop/{agent_id}" + response = requests.get( # noqa: DUO123 + url, + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() - json_response = json.loads(response.content.decode()) - return json_response["stop_agent"] - except ( - requests.exceptions.ConnectionError, - requests.exceptions.TooManyRedirects, - ) as e: - raise IslandAPIConnectionError(e) - except requests.exceptions.Timeout as e: - raise IslandAPITimeoutError(e) - except requests.exceptions.HTTPError as e: - if e.errno >= 500: - raise IslandAPIRequestFailedError(e) - else: - raise IslandAPIRequestError(e) - except json.JSONDecodeError as e: - raise IslandAPIRequestFailedError(e) + json_response = json.loads(response.content.decode()) + return json_response["stop_agent"] + @handle_island_errors + @convert_json_error_to_island_api_error def get_config(self, island_server: str) -> AgentConfiguration: - try: - response = requests.get( # noqa: DUO123 - f"https://{island_server}/api/agent-configuration", - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() + response = requests.get( # noqa: DUO123 + f"https://{island_server}/api/agent-configuration", + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() - config_dict = json.loads(response.text) + config_dict = json.loads(response.text) - logger.debug(f"Received configuration:\n{pformat(config_dict)}") + logger.debug(f"Received configuration:\n{pformat(config_dict)}") - return AgentConfiguration(**config_dict) - except ( - requests.exceptions.ConnectionError, - requests.exceptions.TooManyRedirects, - ) as e: - raise IslandAPIConnectionError(e) - except requests.exceptions.Timeout as e: - raise IslandAPITimeoutError(e) - except requests.exceptions.HTTPError as e: - if e.errno >= 500: - raise IslandAPIRequestFailedError(e) - else: - raise IslandAPIRequestError(e) - except json.JSONDecodeError as e: - raise IslandAPIRequestFailedError(e) + return AgentConfiguration(**config_dict) + @handle_island_errors + @convert_json_error_to_island_api_error def get_credentials_for_propagation(self, island_server: str) -> Sequence[Credentials]: - try: - response = requests.get( # noqa: DUO123 - f"https://{island_server}/api/propagation-credentials", - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() + response = requests.get( # noqa: DUO123 + f"https://{island_server}/api/propagation-credentials", + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() - return [Credentials(**credentials) for credentials in response.json()] - except ( - requests.exceptions.ConnectionError, - requests.exceptions.TooManyRedirects, - ) as e: - raise IslandAPIConnectionError(e) - except requests.exceptions.Timeout as e: - raise IslandAPITimeoutError(e) - except requests.exceptions.HTTPError as e: - if e.errno >= 500: - raise IslandAPIRequestFailedError(e) - else: - raise IslandAPIRequestError(e) - except json.JSONDecodeError as e: - raise IslandAPIRequestFailedError(e) + return [Credentials(**credentials) for credentials in response.json()] def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: serialized_events: List[JSONSerializable] = []