diff --git a/monkey/infection_monkey/exploit/__init__.py b/monkey/infection_monkey/exploit/__init__.py index 195e880ad..22d1fc865 100644 --- a/monkey/infection_monkey/exploit/__init__.py +++ b/monkey/infection_monkey/exploit/__init__.py @@ -1,3 +1,3 @@ -from .i_agent_binary_repository import IAgentBinaryRepository +from .i_agent_binary_repository import IAgentBinaryRepository, RetrievalError from .caching_agent_binary_repository import CachingAgentBinaryRepository from .exploiter_wrapper import ExploiterWrapper diff --git a/monkey/infection_monkey/exploit/caching_agent_binary_repository.py b/monkey/infection_monkey/exploit/caching_agent_binary_repository.py index 745aae112..1d8067da3 100644 --- a/monkey/infection_monkey/exploit/caching_agent_binary_repository.py +++ b/monkey/infection_monkey/exploit/caching_agent_binary_repository.py @@ -1,13 +1,14 @@ import io +import logging import threading from functools import lru_cache -import requests - from common import OperatingSystem -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError -from . import IAgentBinaryRepository +from . import IAgentBinaryRepository, RetrievalError + +logger = logging.getLogger(__name__) class CachingAgentBinaryRepository(IAgentBinaryRepository): @@ -17,9 +18,9 @@ class CachingAgentBinaryRepository(IAgentBinaryRepository): request is actually sent to the island for each requested binary. """ - def __init__(self, island_url: str): - self._island_url = island_url + def __init__(self, island_api_client: IIslandAPIClient): self._lock = threading.Lock() + self._island_api_client = island_api_client def get_agent_binary( self, operating_system: OperatingSystem, architecture: str = None @@ -33,14 +34,7 @@ class CachingAgentBinaryRepository(IAgentBinaryRepository): @lru_cache(maxsize=None) def _download_binary_from_island(self, operating_system: OperatingSystem) -> bytes: - os_name = operating_system.value - - response = requests.get( # noqa: DUO123 - f"{self._island_url}/api/agent-binaries/{os_name}", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - - response.raise_for_status() - - return response.content + try: + return self._island_api_client.get_agent_binary(operating_system) + except IslandAPIError as err: + raise RetrievalError(err) diff --git a/monkey/infection_monkey/exploit/i_agent_binary_repository.py b/monkey/infection_monkey/exploit/i_agent_binary_repository.py index 09de7a696..b78888864 100644 --- a/monkey/infection_monkey/exploit/i_agent_binary_repository.py +++ b/monkey/infection_monkey/exploit/i_agent_binary_repository.py @@ -7,6 +7,12 @@ from common import OperatingSystem # moment, the Island and Agent have different needs, but at some point we should unify these. +class RetrievalError(RuntimeError): + """ + Raised when a repository encounters an error while attempting to retrieve data + """ + + class IAgentBinaryRepository(metaclass=abc.ABCMeta): """ IAgentBinaryRepository provides an interface for other components to access agent binaries. @@ -23,5 +29,6 @@ class IAgentBinaryRepository(metaclass=abc.ABCMeta): :param operating_system: The name of the operating system on which the agent binary will run :param architecture: Reserved :return: A file-like object for the requested agent binary + :raises RetrievalError: If an error occurs when retrieving the agent binary """ pass diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 37feb8942..3b2adfe61 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -3,6 +3,7 @@ import logging import requests +from common import OperatingSystem from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from . import ( @@ -76,3 +77,15 @@ class HTTPIslandAPIClient(IIslandAPIClient): response.raise_for_status() return response.content + + @handle_island_errors + def get_agent_binary(self, operating_system: OperatingSystem): + os_name = operating_system.value + response = requests.get( # noqa: DUO123 + f"{self._api_url}/agent-binaries/{os_name}", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + return response.content diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index fefba9973..fea93d1dc 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional class IIslandAPIClient(ABC): @@ -54,3 +55,20 @@ class IIslandAPIClient(ABC): :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the custom PBA file """ + + @abstractmethod + def get_agent_binary(self, os_name: str) -> Optional[bytes]: + """ + Get an agent binary for the given OS from the island + + :param os_name: The OS on which the agent binary will run + :return: The agent binary file + :raises IslandAPIConnectionError: If the client cannot successfully connect to the island + :raises IslandAPIRequestError: If an error occurs while attempting to connect to the + island due to an issue in the request sent from the client + :raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the + island due to an error on the server + :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island + :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the + agent binary + """ diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 96da82225..21b115735 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -110,13 +110,12 @@ class InfectionMonkey: self._singleton = SystemSingleton() self._opts = self._get_arguments(args) - # TODO: Revisit variable names - server, island_api_client = self._connect_to_island_api() + server, self._island_api_client = self._connect_to_island_api() # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_port = int(self._cmd_island_port) self._control_client = ControlClient( - server_address=server, island_api_client=island_api_client + server_address=server, island_api_client=self._island_api_client ) # TODO Refactor the telemetry messengers to accept control client @@ -315,7 +314,7 @@ class InfectionMonkey: puppet.load_plugin("ssh", SSHFingerprinter(), PluginType.FINGERPRINTER) agent_binary_repository = CachingAgentBinaryRepository( - f"https://{self._control_client.server_address}" + island_api_client=self._island_api_client, ) exploit_wrapper = ExploiterWrapper( self._telemetry_messenger, event_queue, agent_binary_repository diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index bd6bfcb41..22ad87161 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -2,6 +2,7 @@ import pytest import requests import requests_mock +from common import OperatingSystem from infection_monkey.island_api_client import ( HTTPIslandAPIClient, IslandAPIConnectionError, @@ -13,10 +14,12 @@ from infection_monkey.island_api_client import ( SERVER = "1.1.1.1:9999" PBA_FILE = "dummy.pba" +WINDOWS = "windows" ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" +ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}" @pytest.mark.parametrize( @@ -118,3 +121,38 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code) island_api_client.get_pba_file(filename=PBA_FILE) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + (Exception, IslandAPIError), + ], +) +def test_island_api_client__get_agent_binary(actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_AGENT_BINARY_URI, exc=actual_error) + island_api_client.get_agent_binary(operating_system=OperatingSystem.WINDOWS) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client__get_agent_binary_status_code(status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_AGENT_BINARY_URI, status_code=status_code) + island_api_client.get_agent_binary(operating_system=OperatingSystem.WINDOWS)