diff --git a/monkey/infection_monkey/island_api_client/__init__.py b/monkey/infection_monkey/island_api_client/__init__.py index 0dd8a7865..0eb69612e 100644 --- a/monkey/infection_monkey/island_api_client/__init__.py +++ b/monkey/infection_monkey/island_api_client/__init__.py @@ -7,4 +7,4 @@ from .island_api_client_errors import ( ) from .i_island_api_client import IIslandAPIClient from .abstract_island_api_client_factory import AbstractIslandAPIClientFactory -from .http_island_api_client import HTTPIslandAPIClient +from .http_island_api_client import HTTPIslandAPIClient, HTTPIslandAPIClientFactory diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 7118e3e37..e3f23ff81 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -4,11 +4,12 @@ from typing import List, Sequence import requests +from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT -from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from . import ( + AbstractIslandAPIClientFactory, IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, @@ -49,11 +50,16 @@ class HTTPIslandAPIClient(IIslandAPIClient): A client for the Island's HTTP API """ - @handle_island_errors def __init__( + self, + agent_event_serializer_registry: AgentEventSerializerRegistry, + ): + self._agent_event_serializer_registry = agent_event_serializer_registry + + @handle_island_errors + def connect( self, island_server: str, - agent_event_serializer_registry: AgentEventSerializerRegistry = None, ): response = requests.get( # noqa: DUO123 f"https://{island_server}/api?action=is-up", @@ -65,8 +71,6 @@ class HTTPIslandAPIClient(IIslandAPIClient): self._island_server = island_server self._api_url = f"https://{self._island_server}/api" - self._agent_event_serializer_registry = agent_event_serializer_registry - @handle_island_errors def send_log(self, log_contents: str): response = requests.post( # noqa: DUO123 @@ -110,3 +114,14 @@ class HTTPIslandAPIClient(IIslandAPIClient): raise IslandAPIRequestError(err) return serialized_events + + +class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): + def __init__( + self, + agent_event_serializer_registry: AgentEventSerializerRegistry = None, + ): + self._agent_event_serializer_registry = agent_event_serializer_registry + + def create_island_api_client(self): + return HTTPIslandAPIClient(self._agent_event_serializer_registry) diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index d75f2dea5..61fc98c58 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -10,9 +10,9 @@ class IIslandAPIClient(ABC): """ @abstractmethod - def __init__(self, island_server: str): + def connect(self, island_server: str): """ - Construct an island API client and connect it to the island + Connectto the island's API :param island_server: The socket address of the API :raises IslandAPIConnectionError: If the client cannot successfully connect to the island diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index ab3461c73..49952ceca 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -45,7 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.i_puppet import IPuppet, PluginType -from infection_monkey.island_api_client import IIslandAPIClient +from infection_monkey.island_api_client import HTTPIslandAPIClientFactory, IIslandAPIClient from infection_monkey.master import AutomatedMaster from infection_monkey.master.control_channel import ControlChannel from infection_monkey.model import VictimHostFactory @@ -114,12 +114,12 @@ class InfectionMonkey: self._agent_event_serializer_registry = self._setup_agent_event_serializers() # TODO: Revisit variable names - server, island_api_client = self._connect_to_island_api() + server, self._island_api_client = self._connect_to_island_api() # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_port = int(self._cmd_island_port) self._control_client = ControlClient( - server_address=server, island_api_client=island_api_client + server_address=server, island_api_client=self._island_api_client ) # TODO Refactor the telemetry messengers to accept control client @@ -145,7 +145,9 @@ class InfectionMonkey: # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") - server_clients = find_available_island_apis(self._opts.servers) + server_clients = find_available_island_apis( + self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) + ) server, island_api_client = self._select_server(server_clients) @@ -289,7 +291,7 @@ class InfectionMonkey: ), ) event_queue.subscribe_all_events( - AgentEventForwarder(self.island_api_client, agent_event_serializer_registry).send_event + AgentEventForwarder(self._island_api_client, agent_event_serializer_registry).send_event ) def _build_puppet( diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 451fd65e7..92c2fb767 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -7,7 +7,7 @@ from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.island_api_client import ( - HTTPIslandAPIClient, + AbstractIslandAPIClientFactory, IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, @@ -27,7 +27,9 @@ logger = logging.getLogger(__name__) NUM_FIND_SERVER_WORKERS = 32 -def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]: +def find_available_island_apis( + servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory +) -> Mapping[str, Optional[IIslandAPIClient]]: server_list = list(servers) server_iterator = ThreadSafeIterator(server_list.__iter__()) server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {} @@ -35,7 +37,7 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[ run_worker_threads( _find_island_server, "FindIslandServer", - args=(server_iterator, server_results), + args=(server_iterator, server_results, island_api_client_factory), num_workers=NUM_FIND_SERVER_WORKERS, ) @@ -43,18 +45,25 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[ def _find_island_server( - servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]] + servers: Iterator[str], + server_status: MutableMapping[str, Optional[IIslandAPIClient]], + island_api_client_factory: AbstractIslandAPIClientFactory, ): with suppress(StopIteration): server = next(servers) - server_status[server] = _check_if_island_server(server) + server_status[server] = _check_if_island_server(server, island_api_client_factory) -def _check_if_island_server(server: str) -> IIslandAPIClient: +def _check_if_island_server( + server: str, island_api_client_factory: AbstractIslandAPIClientFactory +) -> IIslandAPIClient: logger.debug(f"Trying to connect to server: {server}") try: - return HTTPIslandAPIClient(server) + client = island_api_client_factory.create_island_api_client() + client.connect(server) + + return client except IslandAPIConnectionError as err: logger.error(f"Unable to connect to server/relay {server}: {err}") except IslandAPITimeoutError as err: diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 607b18e7b..530713aac 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -1,13 +1,14 @@ -import pytest from uuid import UUID + +import pytest import requests import requests_mock -from common.agent_events import AbstractAgentEvent + from common.agent_event_serializers import ( AgentEventSerializerRegistry, PydanticAgentEventSerializer, ) - +from common.agent_events import AbstractAgentEvent from infection_monkey.island_api_client import ( HTTPIslandAPIClient, IslandAPIConnectionError, @@ -28,6 +29,32 @@ ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") +class Event1(AbstractAgentEvent): + a: int + + +class Event2(AbstractAgentEvent): + b: str + + +class Event3(AbstractAgentEvent): + c: int + + +@pytest.fixture +def agent_event_serializer_registry(): + agent_event_serializer_registry = AgentEventSerializerRegistry() + agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) + agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) + + return agent_event_serializer_registry + + +@pytest.fixture +def island_api_client(agent_event_serializer_registry): + return HTTPIslandAPIClient(agent_event_serializer_registry) + + @pytest.mark.parametrize( "actual_error, expected_error", [ @@ -36,12 +63,12 @@ AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") (Exception, IslandAPIError), ], ) -def test_island_api_client(actual_error, expected_error): +def test_island_api_client(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI, exc=actual_error) with pytest.raises(expected_error): - HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) @pytest.mark.parametrize( @@ -51,12 +78,12 @@ def test_island_api_client(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client__status_code(status_code, expected_error): +def test_island_api_client__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI, status_code=status_code) with pytest.raises(expected_error): - HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) @pytest.mark.parametrize( @@ -67,10 +94,10 @@ def test_island_api_client__status_code(status_code, expected_error): (Exception, IslandAPIError), ], ) -def test_island_api_client__send_log(actual_error, expected_error): +def test_island_api_client__send_log(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_LOG_URI, exc=actual_error) @@ -84,10 +111,10 @@ def test_island_api_client__send_log(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_log__status_code(status_code, expected_error): +def test_island_api_client_send_log__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_LOG_URI, status_code=status_code) @@ -102,10 +129,10 @@ def test_island_api_client_send_log__status_code(status_code, expected_error): (Exception, IslandAPIError), ], ) -def test_island_api_client__get_pba_file(actual_error, expected_error): +def test_island_api_client__get_pba_file(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error) @@ -119,37 +146,19 @@ def test_island_api_client__get_pba_file(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_get_pba_file__status_code(status_code, expected_error): +def test_island_api_client_get_pba_file__status_code( + island_api_client, status_code, expected_error +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code) island_api_client.get_pba_file(filename=PBA_FILE) -class Event1(AbstractAgentEvent): - a: int - - -class Event2(AbstractAgentEvent): - b: str - - -class Event3(AbstractAgentEvent): - c: int - -@pytest.fixture -def agent_event_serializer_registry(): - agent_event_serializer_registry = AgentEventSerializerRegistry() - agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) - agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) - - return agent_event_serializer_registry - - -def test_island_api_client_send_events__serialization(agent_event_serializer_registry): +def test_island_api_client_send_events__serialization(island_api_client): events_to_send = [ Event1(source=AGENT_ID, timestamp=0, a=1), Event2(source=AGENT_ID, timestamp=0, b="hello"), @@ -176,16 +185,17 @@ def test_island_api_client_send_events__serialization(agent_event_serializer_reg with requests_mock.Mocker() as m: m.get(ISLAND_URI) m.post(ISLAND_SEND_EVENTS_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) island_api_client.send_events(events=events_to_send) assert m.last_request.json() == expected_json -def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry): + +def test_island_api_client_send_events__serialization_failed(island_api_client): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) with pytest.raises(IslandAPIRequestError): m.post(ISLAND_SEND_EVENTS_URI) @@ -200,12 +210,10 @@ def test_island_api_client_send_events__serialization_failed(agent_event_seriali (Exception, IslandAPIError), ], ) -def test_island_api_client__send_events( - actual_error, expected_error, agent_event_serializer_registry -): +def test_island_api_client__send_events(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) @@ -219,12 +227,10 @@ def test_island_api_client__send_events( (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_events__status_code( - status_code, expected_error, agent_event_serializer_registry -): +def test_island_api_client_send_events__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index dd00ebc94..affb61aff 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,7 +1,12 @@ import pytest import requests_mock -from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError +from common.agent_event_serializers import AgentEventSerializerRegistry +from infection_monkey.island_api_client import ( + HTTPIslandAPIClientFactory, + IIslandAPIClient, + IslandAPIConnectionError, +) from infection_monkey.network.relay.utils import find_available_island_apis SERVER_1 = "1.1.1.1:12312" @@ -13,6 +18,11 @@ SERVER_4 = "4.4.4.4:5000" servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] +@pytest.fixture +def island_api_client_factory(): + return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry()) + + @pytest.mark.parametrize( "expected_available_servers, server_response_pairs", [ @@ -24,12 +34,14 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ), ], ) -def test_find_available_island_apis(expected_available_servers, server_response_pairs): +def test_find_available_island_apis( + expected_available_servers, server_response_pairs, island_api_client_factory +): with requests_mock.Mocker() as mock: for server, response in server_response_pairs: mock.get(f"https://{server}/api?action=is-up", **response) - available_apis = find_available_island_apis(servers) + available_apis = find_available_island_apis(servers, island_api_client_factory) assert len(available_apis) == len(server_response_pairs) @@ -40,14 +52,14 @@ def test_find_available_island_apis(expected_available_servers, server_response_ assert island_api_client is None -def test_find_available_island_apis__multiple_successes(): +def test_find_available_island_apis__multiple_successes(island_api_client_factory): available_servers = [SERVER_2, SERVER_3] with requests_mock.Mocker() as mock: mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError) for server in available_servers: mock.get(f"https://{server}/api?action=is-up", text="") - available_apis = find_available_island_apis(servers) + available_apis = find_available_island_apis(servers, island_api_client_factory) assert available_apis[SERVER_1] is None assert available_apis[SERVER_4] is None