Agent: Use SocketAddress in IIslandAPIClient

This commit is contained in:
Kekoa Kaaikala 2022-09-23 21:28:39 +00:00 committed by Shreya Malviya
parent 90890106f7
commit c4804f06a9
6 changed files with 82 additions and 37 deletions

View File

@ -16,6 +16,7 @@ from common.common_consts.timeouts import (
SHORT_REQUEST_TIMEOUT, SHORT_REQUEST_TIMEOUT,
) )
from common.credentials import Credentials from common.credentials import Credentials
from common.types import SocketAddress
from . import ( from . import (
AbstractIslandAPIClientFactory, AbstractIslandAPIClientFactory,
@ -79,7 +80,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
@handle_island_errors @handle_island_errors
def connect( def connect(
self, self,
island_server: str, island_server: SocketAddress,
): ):
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"https://{island_server}/api?action=is-up", f"https://{island_server}/api?action=is-up",
@ -88,8 +89,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
) )
response.raise_for_status() response.raise_for_status()
self._island_server = island_server self._api_url = f"https://{island_server}/api"
self._api_url = f"https://{self._island_server}/api"
@handle_island_errors @handle_island_errors
def send_log(self, log_contents: str): def send_log(self, log_contents: str):

View File

@ -5,6 +5,7 @@ from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials from common.credentials import Credentials
from common.types import SocketAddress
class IIslandAPIClient(ABC): class IIslandAPIClient(ABC):
@ -13,7 +14,7 @@ class IIslandAPIClient(ABC):
""" """
@abstractmethod @abstractmethod
def connect(self, island_server: str): def connect(self, island_server: SocketAddress):
""" """
Connect to the island's API Connect to the island's API

View File

@ -5,7 +5,7 @@ import subprocess
import sys import sys
from ipaddress import IPv4Interface from ipaddress import IPv4Interface
from pathlib import Path, WindowsPath from pathlib import Path, WindowsPath
from typing import List, Mapping, Optional, Tuple from typing import List, Optional, Tuple
from pubsub.core import Publisher from pubsub.core import Publisher
@ -55,6 +55,7 @@ from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.relay import TCPRelay from infection_monkey.network.relay import TCPRelay
from infection_monkey.network.relay.utils import ( from infection_monkey.network.relay.utils import (
IslandAPISearchResults,
find_available_island_apis, find_available_island_apis,
notify_disconnect, notify_disconnect,
send_remove_from_waitlist_control_message_to_relays, send_remove_from_waitlist_control_message_to_relays,
@ -156,7 +157,7 @@ class InfectionMonkey:
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}") logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}")
server_clients = find_available_island_apis( server_clients = find_available_island_apis(
self._server_strings, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
) )
server, island_api_client = self._select_server(server_clients) server, island_api_client = self._select_server(server_clients)
@ -189,11 +190,11 @@ class InfectionMonkey:
self._island_api_client.register_agent(agent_registration_data) self._island_api_client.register_agent(agent_registration_data)
def _select_server( def _select_server(
self, server_clients: Mapping[str, Optional[IIslandAPIClient]] self, server_clients: IslandAPISearchResults
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: ) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
for server in self._server_strings: for result in server_clients:
if server_clients[server]: if result.client is not None:
return server, server_clients[server] return result.server, result.client
return None, None return None, None

View File

@ -1,7 +1,8 @@
import logging import logging
import socket import socket
from contextlib import suppress from contextlib import suppress
from typing import Dict, Iterable, Iterator, Optional from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.types import SocketAddress from common.types import SocketAddress
@ -25,38 +26,47 @@ logger = logging.getLogger(__name__)
# practical purposes. Revisit this if it's not. # practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32 NUM_FIND_SERVER_WORKERS = 32
IslandAPISearchResults = Dict[str, Optional[IIslandAPIClient]]
@dataclass
class IslandAPISearchResult:
server: SocketAddress
client: Optional[IIslandAPIClient]
IslandAPISearchResults = List[IslandAPISearchResult]
def find_available_island_apis( def find_available_island_apis(
servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
) -> IslandAPISearchResults: ) -> IslandAPISearchResults:
server_list = list(servers) server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__()) server_iterator = ThreadSafeIterator(enumerate(server_list.__iter__()))
server_results: IslandAPISearchResults = {} results: Dict[int, IslandAPISearchResult] = {}
run_worker_threads( run_worker_threads(
_find_island_server, _find_island_server,
"FindIslandServer", "FindIslandServer",
args=(server_iterator, server_results, island_api_client_factory), args=(server_iterator, results, island_api_client_factory),
num_workers=NUM_FIND_SERVER_WORKERS, num_workers=NUM_FIND_SERVER_WORKERS,
) )
return server_results return [results[i] for i in sorted(results.keys())]
def _find_island_server( def _find_island_server(
servers: Iterator[str], servers: Iterator[Tuple[int, SocketAddress]],
server_results: IslandAPISearchResults, server_results: Dict[int, IslandAPISearchResult],
island_api_client_factory: AbstractIslandAPIClientFactory, island_api_client_factory: AbstractIslandAPIClientFactory,
): ):
with suppress(StopIteration): with suppress(StopIteration):
server = next(servers) index, server = next(servers)
server_results[server] = _check_if_island_server(server, island_api_client_factory) server_results[index] = IslandAPISearchResult(
server, _check_if_island_server(server, island_api_client_factory)
)
def _check_if_island_server( def _check_if_island_server(
server: str, island_api_client_factory: AbstractIslandAPIClientFactory server: SocketAddress, island_api_client_factory: AbstractIslandAPIClientFactory
) -> Optional[IIslandAPIClient]: ) -> Optional[IIslandAPIClient]:
logger.debug(f"Trying to connect to server: {server}") logger.debug(f"Trying to connect to server: {server}")

View File

@ -30,7 +30,7 @@ AGENT_REGISTRATION = AgentRegistrationData(
machine_hardware_id=1, machine_hardware_id=1,
start_time=0, start_time=0,
parent_id=None, parent_id=None,
cc_server=SERVER, cc_server=str(SERVER),
network_interfaces=[], network_interfaces=[],
) )

View File

@ -1,18 +1,21 @@
from typing import Callable, Optional
import pytest import pytest
import requests_mock import requests_mock
from common.agent_event_serializers import AgentEventSerializerRegistry from common.agent_event_serializers import AgentEventSerializerRegistry
from common.types import SocketAddress
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
HTTPIslandAPIClientFactory, HTTPIslandAPIClientFactory,
IIslandAPIClient, IIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
) )
from infection_monkey.network.relay.utils import find_available_island_apis from infection_monkey.network.relay.utils import IslandAPISearchResult, find_available_island_apis
SERVER_1 = "1.1.1.1:12312" SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
SERVER_2 = "2.2.2.2:4321" SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
SERVER_3 = "3.3.3.3:3142" SERVER_3 = SocketAddress(ip="3.3.3.3", port=3142)
SERVER_4 = "4.4.4.4:5000" SERVER_4 = SocketAddress(ip="4.4.4.4", port=5000)
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@ -45,11 +48,41 @@ def test_find_available_island_apis(
assert len(available_apis) == len(server_response_pairs) assert len(available_apis) == len(server_response_pairs)
for server, island_api_client in available_apis.items(): for result in available_apis:
if server in expected_available_servers: if result.server in expected_available_servers:
assert island_api_client is not None assert result.client is not None
else: else:
assert island_api_client is None assert result.client is None
def test_find_available_island_apis__preserves_input_order(island_api_client_factory):
available_servers = [SERVER_2, SERVER_3]
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
for server in available_servers:
mock.get(f"https://{server}/api?action=is-up", text="")
available_apis = find_available_island_apis(servers, island_api_client_factory)
for index in range(len(servers)):
assert available_apis[index].server == servers[index]
def _is_none(value) -> bool:
return value is None
def _is_island_client(value) -> bool:
return isinstance(value, IIslandAPIClient)
def _assert_server_and_predicate(
result: IslandAPISearchResult,
server: SocketAddress,
predicate: Callable[[Optional[IIslandAPIClient]], bool],
):
assert result.server == server
assert predicate(result.client)
def test_find_available_island_apis__multiple_successes(island_api_client_factory): def test_find_available_island_apis__multiple_successes(island_api_client_factory):
@ -61,7 +94,7 @@ def test_find_available_island_apis__multiple_successes(island_api_client_factor
available_apis = find_available_island_apis(servers, island_api_client_factory) available_apis = find_available_island_apis(servers, island_api_client_factory)
assert available_apis[SERVER_1] is None _assert_server_and_predicate(available_apis[0], SERVER_1, _is_none)
assert available_apis[SERVER_4] is None _assert_server_and_predicate(available_apis[1], SERVER_2, _is_island_client)
for server in available_servers: _assert_server_and_predicate(available_apis[2], SERVER_3, _is_island_client)
assert isinstance(available_apis[server], IIslandAPIClient) _assert_server_and_predicate(available_apis[3], SERVER_4, _is_none)