diff --git a/monkey/infection_monkey/control.py b/monkey/infection_monkey/control.py index 58ab43fa6..440f58712 100644 --- a/monkey/infection_monkey/control.py +++ b/monkey/infection_monkey/control.py @@ -6,9 +6,10 @@ from socket import gethostname import requests from urllib3 import disable_warnings -from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT +from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.network.network_utils import get_my_ip_addresses from infection_monkey.config import GUID +from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.network.info import get_host_subnets from infection_monkey.utils import agent_process @@ -16,8 +17,6 @@ disable_warnings() # noqa DUO131 logger = logging.getLogger(__name__) -PBA_FILE_DOWNLOAD = "https://%s/api/pba/download/%s" - class ControlClient: # TODO When we have mechanism that support telemetry messenger @@ -25,8 +24,9 @@ class ControlClient: # https://github.com/guardicore/monkey/blob/133f7f5da131b481561141171827d1f9943f6aec/monkey/infection_monkey/telemetry/base_telem.py control_client_object = None - def __init__(self, server_address: str): + def __init__(self, server_address: str, island_api_client: IIslandAPIClient): self.server_address = server_address + self._island_api_client = island_api_client def wakeup(self, parent=None): if parent: @@ -78,22 +78,12 @@ class ControlClient: return try: telemetry = {"monkey_guid": GUID, "log": json.dumps(log)} - requests.post( # noqa: DUO123 - "https://%s/api/log" % (self.server_address,), - data=json.dumps(telemetry), - headers={"content-type": "application/json"}, - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) + self._island_api_client.send_log(json.dumps(telemetry)) except Exception as exc: logger.warning(f"Error connecting to control server {self.server_address}: {exc}") def get_pba_file(self, filename): try: - return requests.get( # noqa: DUO123 - PBA_FILE_DOWNLOAD % (self.server_address, filename), - verify=False, - timeout=LONG_REQUEST_TIMEOUT, - ) - except requests.exceptions.RequestException: - return False + return self._island_api_client.get_pba_file(filename) + except Exception as exc: + logger.warning(f"Error connecting to control server {self.server_address}: {exc}") diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index d9f9b1d9e..37feb8942 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,29 +1,78 @@ +import functools import logging import requests -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT -from . import IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, IslandAPITimeoutError +from . import ( + IIslandAPIClient, + IslandAPIConnectionError, + IslandAPIError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, +) logger = logging.getLogger(__name__) +def handle_island_errors(fn): + @functools.wraps(fn) + def decorated(*args, **kwargs): + try: + return fn(*args, **kwargs) + except requests.exceptions.ConnectionError as err: + raise IslandAPIConnectionError(err) + except requests.exceptions.HTTPError as err: + if 400 <= err.response.status_code < 500: + raise IslandAPIRequestError(err) + elif 500 <= err.response.status_code < 600: + raise IslandAPIRequestFailedError(err) + else: + raise IslandAPIError(err) + except TimeoutError as err: + raise IslandAPITimeoutError(err) + except Exception as err: + raise IslandAPIError(err) + + return decorated + + class HTTPIslandAPIClient(IIslandAPIClient): """ A client for the Island's HTTP API """ + @handle_island_errors def __init__(self, island_server: str): - try: - requests.get( # noqa: DUO123 - f"https://{island_server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - except requests.exceptions.ConnectionError as err: - raise IslandAPIConnectionError(err) - except TimeoutError as err: - raise IslandAPITimeoutError(err) - except Exception as err: - raise IslandAPIError(err) + response = requests.get( # noqa: DUO123 + f"https://{island_server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + self._island_server = island_server + self._api_url = f"https://{self._island_server}/api" + + @handle_island_errors + def send_log(self, log_contents: str): + response = requests.post( # noqa: DUO123 + f"{self._api_url}/log", + json=log_contents, + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + @handle_island_errors + def get_pba_file(self, filename: str): + response = requests.get( # noqa: DUO123 + f"{self._api_url}/pba/download/{filename}", + verify=False, + timeout=LONG_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + return response.content diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 4a168b3d5..fefba9973 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -9,11 +9,48 @@ class IIslandAPIClient(ABC): @abstractmethod def __init__(self, island_server: str): """ - Construct and island API client and connect it to the island + Construct an island API client and connect it to the island :param island_server: The socket address of the API :raises IslandAPIConnectionError: If the client cannot successfully connect to the island + :raises IslandAPIRequestError: If an error occurs while attempting to connect to the + island due to an issue in the request sent from the client + :raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the + island due to an error on the server :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island :raises IslandAPIError: If an unexpected error occurs while attempting to connect to the island """ + + @abstractmethod + def send_log(self, log_contents: str): + """ + Send the contents of the agent's log to the island + + :param log_contents: The contents of the agent's log + :raises IslandAPIConnectionError: If the client cannot successfully connect to the island + :raises IslandAPIRequestError: If an error occurs while attempting to connect to the + island due to an issue in the request sent from the client + :raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the + island due to an error on the server + :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island + :raises IslandAPIError: If an unexpected error occurs while attempting to send the + contents of the agent's log to the island + """ + + @abstractmethod + def get_pba_file(self, filename: str) -> bytes: + """ + Get a custom PBA file from the island + + :param filename: The name of the custom PBA file + :return: The contents of the custom PBA file in bytes + :raises IslandAPIConnectionError: If the client cannot successfully connect to the island + :raises IslandAPIRequestError: If an error occurs while attempting to connect to the + island due to an issue in the request sent from the client + :raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the + island due to an error on the server + :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island + :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the + custom PBA file + """ diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 680db9f6f..96da82225 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -115,7 +115,9 @@ class InfectionMonkey: # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_port = int(self._cmd_island_port) - self._control_client = ControlClient(server_address=server) + self._control_client = ControlClient( + server_address=server, island_api_client=island_api_client + ) # TODO Refactor the telemetry messengers to accept control client # and remove control_client_object diff --git a/monkey/infection_monkey/post_breach/custom_pba/custom_pba.py b/monkey/infection_monkey/post_breach/custom_pba/custom_pba.py index f5087a91a..34fb73147 100644 --- a/monkey/infection_monkey/post_breach/custom_pba/custom_pba.py +++ b/monkey/infection_monkey/post_breach/custom_pba/custom_pba.py @@ -76,7 +76,7 @@ class CustomPBA(PBA): pba_file_contents = self.control_client.get_pba_file(filename) status = None - if not pba_file_contents or not pba_file_contents.content: + if not pba_file_contents: logger.error("Island didn't respond with post breach file.") status = ScanStatus.SCANNED @@ -97,7 +97,7 @@ class CustomPBA(PBA): try: with open(os.path.join(dst_dir, filename), "wb") as written_PBA_file: - written_PBA_file.write(pba_file_contents.content) + written_PBA_file.write(pba_file_contents) return True except IOError as e: logger.error("Can not upload post breach file to target machine: %s" % e) diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index f213b0569..bd6bfcb41 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -6,12 +6,17 @@ from infection_monkey.island_api_client import ( HTTPIslandAPIClient, IslandAPIConnectionError, IslandAPIError, + IslandAPIRequestError, + IslandAPIRequestFailedError, IslandAPITimeoutError, ) SERVER = "1.1.1.1:9999" +PBA_FILE = "dummy.pba" ISLAND_URI = f"https://{SERVER}/api?action=is-up" +ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" +ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" @pytest.mark.parametrize( @@ -28,3 +33,88 @@ def test_island_api_client(actual_error, expected_error): with pytest.raises(expected_error): HTTPIslandAPIClient(SERVER) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client__status_code(status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI, status_code=status_code) + + with pytest.raises(expected_error): + HTTPIslandAPIClient(SERVER) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + (Exception, IslandAPIError), + ], +) +def test_island_api_client__send_log(actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_SEND_LOG_URI, exc=actual_error) + island_api_client.send_log(log_contents="some_data") + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_send_log__status_code(status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_SEND_LOG_URI, status_code=status_code) + island_api_client.send_log(log_contents="some_data") + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + (Exception, IslandAPIError), + ], +) +def test_island_api_client__get_pba_file(actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error) + island_api_client.get_pba_file(filename=PBA_FILE) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_get_pba_file__status_code(status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code) + island_api_client.get_pba_file(filename=PBA_FILE) diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 83fb82005..53001be59 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -9,12 +9,6 @@ from common.agent_configuration.agent_sub_configurations import ( ) from common.credentials import Credentials, LMHash, NTHash from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory -from infection_monkey.island_api_client import ( - HTTPIslandAPIClient, - IIslandAPIClient, - IslandAPIRequestError, - IslandAPIRequestFailedError, -) from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue from monkey_island.cc.models import Report from monkey_island.cc.models.networkmap import Arc, NetworkMap @@ -334,9 +328,3 @@ CC_TUNNEL IslandEventTopic.AGENT_CONNECTED IslandEventTopic.CLEAR_SIMULATION_DATA IslandEventTopic.RESET_AGENT_CONFIGURATION - -# TODO: Remove after #2292 is closed -IIslandAPIClient -HTTPIslandAPIClient -IslandAPIRequestFailedError -IslandAPIRequestError