Agent: Use SocketAddress in IIslandAPIClient

This commit is contained in:
Kekoa Kaaikala 2022-09-23 21:28:39 +00:00 committed by Shreya Malviya
parent 90890106f7
commit c4804f06a9
6 changed files with 82 additions and 37 deletions

View File

@ -16,6 +16,7 @@ from common.common_consts.timeouts import (
SHORT_REQUEST_TIMEOUT,
)
from common.credentials import Credentials
from common.types import SocketAddress
from . import (
AbstractIslandAPIClientFactory,
@ -79,7 +80,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
@handle_island_errors
def connect(
self,
island_server: str,
island_server: SocketAddress,
):
response = requests.get( # noqa: DUO123
f"https://{island_server}/api?action=is-up",
@ -88,8 +89,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
)
response.raise_for_status()
self._island_server = island_server
self._api_url = f"https://{self._island_server}/api"
self._api_url = f"https://{island_server}/api"
@handle_island_errors
def send_log(self, log_contents: str):

View File

@ -5,6 +5,7 @@ from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials
from common.types import SocketAddress
class IIslandAPIClient(ABC):
@ -13,7 +14,7 @@ class IIslandAPIClient(ABC):
"""
@abstractmethod
def connect(self, island_server: str):
def connect(self, island_server: SocketAddress):
"""
Connect to the island's API

View File

@ -5,7 +5,7 @@ import subprocess
import sys
from ipaddress import IPv4Interface
from pathlib import Path, WindowsPath
from typing import List, Mapping, Optional, Tuple
from typing import List, Optional, Tuple
from pubsub.core import Publisher
@ -55,6 +55,7 @@ from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.relay import TCPRelay
from infection_monkey.network.relay.utils import (
IslandAPISearchResults,
find_available_island_apis,
notify_disconnect,
send_remove_from_waitlist_control_message_to_relays,
@ -156,7 +157,7 @@ class InfectionMonkey:
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}")
server_clients = find_available_island_apis(
self._server_strings, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
)
server, island_api_client = self._select_server(server_clients)
@ -189,11 +190,11 @@ class InfectionMonkey:
self._island_api_client.register_agent(agent_registration_data)
def _select_server(
self, server_clients: Mapping[str, Optional[IIslandAPIClient]]
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
for server in self._server_strings:
if server_clients[server]:
return server, server_clients[server]
self, server_clients: IslandAPISearchResults
) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
for result in server_clients:
if result.client is not None:
return result.server, result.client
return None, None

View File

@ -1,7 +1,8 @@
import logging
import socket
from contextlib import suppress
from typing import Dict, Iterable, Iterator, Optional
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.types import SocketAddress
@ -25,38 +26,47 @@ logger = logging.getLogger(__name__)
# practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32
IslandAPISearchResults = Dict[str, Optional[IIslandAPIClient]]
@dataclass
class IslandAPISearchResult:
server: SocketAddress
client: Optional[IIslandAPIClient]
IslandAPISearchResults = List[IslandAPISearchResult]
def find_available_island_apis(
servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory
servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
) -> IslandAPISearchResults:
server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: IslandAPISearchResults = {}
server_iterator = ThreadSafeIterator(enumerate(server_list.__iter__()))
results: Dict[int, IslandAPISearchResult] = {}
run_worker_threads(
_find_island_server,
"FindIslandServer",
args=(server_iterator, server_results, island_api_client_factory),
args=(server_iterator, results, island_api_client_factory),
num_workers=NUM_FIND_SERVER_WORKERS,
)
return server_results
return [results[i] for i in sorted(results.keys())]
def _find_island_server(
servers: Iterator[str],
server_results: IslandAPISearchResults,
servers: Iterator[Tuple[int, SocketAddress]],
server_results: Dict[int, IslandAPISearchResult],
island_api_client_factory: AbstractIslandAPIClientFactory,
):
with suppress(StopIteration):
server = next(servers)
server_results[server] = _check_if_island_server(server, island_api_client_factory)
index, server = next(servers)
server_results[index] = IslandAPISearchResult(
server, _check_if_island_server(server, island_api_client_factory)
)
def _check_if_island_server(
server: str, island_api_client_factory: AbstractIslandAPIClientFactory
server: SocketAddress, island_api_client_factory: AbstractIslandAPIClientFactory
) -> Optional[IIslandAPIClient]:
logger.debug(f"Trying to connect to server: {server}")

View File

@ -30,7 +30,7 @@ AGENT_REGISTRATION = AgentRegistrationData(
machine_hardware_id=1,
start_time=0,
parent_id=None,
cc_server=SERVER,
cc_server=str(SERVER),
network_interfaces=[],
)

View File

@ -1,18 +1,21 @@
from typing import Callable, Optional
import pytest
import requests_mock
from common.agent_event_serializers import AgentEventSerializerRegistry
from common.types import SocketAddress
from infection_monkey.island_api_client import (
HTTPIslandAPIClientFactory,
IIslandAPIClient,
IslandAPIConnectionError,
)
from infection_monkey.network.relay.utils import find_available_island_apis
from infection_monkey.network.relay.utils import IslandAPISearchResult, find_available_island_apis
SERVER_1 = "1.1.1.1:12312"
SERVER_2 = "2.2.2.2:4321"
SERVER_3 = "3.3.3.3:3142"
SERVER_4 = "4.4.4.4:5000"
SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
SERVER_3 = SocketAddress(ip="3.3.3.3", port=3142)
SERVER_4 = SocketAddress(ip="4.4.4.4", port=5000)
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@ -45,11 +48,41 @@ def test_find_available_island_apis(
assert len(available_apis) == len(server_response_pairs)
for server, island_api_client in available_apis.items():
if server in expected_available_servers:
assert island_api_client is not None
for result in available_apis:
if result.server in expected_available_servers:
assert result.client is not None
else:
assert island_api_client is None
assert result.client is None
def test_find_available_island_apis__preserves_input_order(island_api_client_factory):
available_servers = [SERVER_2, SERVER_3]
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
for server in available_servers:
mock.get(f"https://{server}/api?action=is-up", text="")
available_apis = find_available_island_apis(servers, island_api_client_factory)
for index in range(len(servers)):
assert available_apis[index].server == servers[index]
def _is_none(value) -> bool:
return value is None
def _is_island_client(value) -> bool:
return isinstance(value, IIslandAPIClient)
def _assert_server_and_predicate(
result: IslandAPISearchResult,
server: SocketAddress,
predicate: Callable[[Optional[IIslandAPIClient]], bool],
):
assert result.server == server
assert predicate(result.client)
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
@ -61,7 +94,7 @@ def test_find_available_island_apis__multiple_successes(island_api_client_factor
available_apis = find_available_island_apis(servers, island_api_client_factory)
assert available_apis[SERVER_1] is None
assert available_apis[SERVER_4] is None
for server in available_servers:
assert isinstance(available_apis[server], IIslandAPIClient)
_assert_server_and_predicate(available_apis[0], SERVER_1, _is_none)
_assert_server_and_predicate(available_apis[1], SERVER_2, _is_island_client)
_assert_server_and_predicate(available_apis[2], SERVER_3, _is_island_client)
_assert_server_and_predicate(available_apis[3], SERVER_4, _is_none)