diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index cce8791e1..6a972b6fa 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -1,4 +1,5 @@ import logging +from functools import wraps from typing import Optional, Sequence from uuid import UUID @@ -24,12 +25,30 @@ disable_warnings() # noqa: DUO131 logger = logging.getLogger(__name__) +def handle_island_api_errors(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + print(args) + func(*args, **kwargs) + except ( + IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, + ) as e: + raise IslandCommunicationError(e) + + return wrapper + + class ControlChannel(IControlChannel): def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient): self._agent_id = agent_id self._control_channel_server = server self._island_api_client = api_client + @handle_island_api_errors def register_agent(self, parent: Optional[UUID] = None): agent_registration_data = AgentRegistrationData( id=get_agent_id(), @@ -41,47 +60,21 @@ class ControlChannel(IControlChannel): network_interfaces=get_network_interfaces(), ) - try: - self._island_api_client.register_agent(agent_registration_data) - except (IslandAPIConnectionError, IslandAPITimeoutError) as e: - raise IslandCommunicationError(e) + self._island_api_client.register_agent(agent_registration_data) + @handle_island_api_errors def should_agent_stop(self) -> bool: if not self._control_channel_server: logger.error("Agent should stop because it can't connect to the C&C server.") return True - try: - return self._island_api_client.should_agent_stop( - self._control_channel_server, self._agent_id - ) - except ( - IslandAPIConnectionError, - IslandAPIRequestError, - IslandAPIRequestFailedError, - IslandAPITimeoutError, - ) as e: - raise IslandCommunicationError(e) + return self._island_api_client.should_agent_stop( + self._control_channel_server, self._agent_id + ) + @handle_island_api_errors def get_config(self) -> AgentConfiguration: - try: - return self._island_api_client.get_config(self._control_channel_server) - except ( - IslandAPIConnectionError, - IslandAPIRequestError, - IslandAPIRequestFailedError, - IslandAPITimeoutError, - ) as e: - raise IslandCommunicationError(e) + return self._island_api_client.get_config(self._control_channel_server) + @handle_island_api_errors def get_credentials_for_propagation(self) -> Sequence[Credentials]: - try: - return self._island_api_client.get_credentials_for_propagation( - self._control_channel_server - ) - except ( - IslandAPIConnectionError, - IslandAPIRequestError, - IslandAPIRequestFailedError, - IslandAPITimeoutError, - ) as e: - raise IslandCommunicationError(e) + return self._island_api_client.get_credentials_for_propagation(self._control_channel_server)