forked from p15670423/monkey
Agent: Use SocketAddress in IIslandAPIClient
This commit is contained in:
parent
90890106f7
commit
c4804f06a9
|
@ -16,6 +16,7 @@ from common.common_consts.timeouts import (
|
||||||
SHORT_REQUEST_TIMEOUT,
|
SHORT_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
|
from common.types import SocketAddress
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
AbstractIslandAPIClientFactory,
|
AbstractIslandAPIClientFactory,
|
||||||
|
@ -79,7 +80,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
@handle_island_errors
|
@handle_island_errors
|
||||||
def connect(
|
def connect(
|
||||||
self,
|
self,
|
||||||
island_server: str,
|
island_server: SocketAddress,
|
||||||
):
|
):
|
||||||
response = requests.get( # noqa: DUO123
|
response = requests.get( # noqa: DUO123
|
||||||
f"https://{island_server}/api?action=is-up",
|
f"https://{island_server}/api?action=is-up",
|
||||||
|
@ -88,8 +89,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
self._island_server = island_server
|
self._api_url = f"https://{island_server}/api"
|
||||||
self._api_url = f"https://{self._island_server}/api"
|
|
||||||
|
|
||||||
@handle_island_errors
|
@handle_island_errors
|
||||||
def send_log(self, log_contents: str):
|
def send_log(self, log_contents: str):
|
||||||
|
|
|
@ -5,6 +5,7 @@ from common import AgentRegistrationData, AgentSignals, OperatingSystem
|
||||||
from common.agent_configuration import AgentConfiguration
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
|
from common.types import SocketAddress
|
||||||
|
|
||||||
|
|
||||||
class IIslandAPIClient(ABC):
|
class IIslandAPIClient(ABC):
|
||||||
|
@ -13,7 +14,7 @@ class IIslandAPIClient(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def connect(self, island_server: str):
|
def connect(self, island_server: SocketAddress):
|
||||||
"""
|
"""
|
||||||
Connect to the island's API
|
Connect to the island's API
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
from ipaddress import IPv4Interface
|
from ipaddress import IPv4Interface
|
||||||
from pathlib import Path, WindowsPath
|
from pathlib import Path, WindowsPath
|
||||||
from typing import List, Mapping, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from pubsub.core import Publisher
|
from pubsub.core import Publisher
|
||||||
|
|
||||||
|
@ -55,6 +55,7 @@ from infection_monkey.network.firewall import app as firewall
|
||||||
from infection_monkey.network.info import get_free_tcp_port
|
from infection_monkey.network.info import get_free_tcp_port
|
||||||
from infection_monkey.network.relay import TCPRelay
|
from infection_monkey.network.relay import TCPRelay
|
||||||
from infection_monkey.network.relay.utils import (
|
from infection_monkey.network.relay.utils import (
|
||||||
|
IslandAPISearchResults,
|
||||||
find_available_island_apis,
|
find_available_island_apis,
|
||||||
notify_disconnect,
|
notify_disconnect,
|
||||||
send_remove_from_waitlist_control_message_to_relays,
|
send_remove_from_waitlist_control_message_to_relays,
|
||||||
|
@ -156,7 +157,7 @@ class InfectionMonkey:
|
||||||
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
||||||
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}")
|
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}")
|
||||||
server_clients = find_available_island_apis(
|
server_clients = find_available_island_apis(
|
||||||
self._server_strings, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
|
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
|
||||||
)
|
)
|
||||||
|
|
||||||
server, island_api_client = self._select_server(server_clients)
|
server, island_api_client = self._select_server(server_clients)
|
||||||
|
@ -189,11 +190,11 @@ class InfectionMonkey:
|
||||||
self._island_api_client.register_agent(agent_registration_data)
|
self._island_api_client.register_agent(agent_registration_data)
|
||||||
|
|
||||||
def _select_server(
|
def _select_server(
|
||||||
self, server_clients: Mapping[str, Optional[IIslandAPIClient]]
|
self, server_clients: IslandAPISearchResults
|
||||||
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
|
||||||
for server in self._server_strings:
|
for result in server_clients:
|
||||||
if server_clients[server]:
|
if result.client is not None:
|
||||||
return server, server_clients[server]
|
return result.server, result.client
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import Dict, Iterable, Iterator, Optional
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
||||||
from common.types import SocketAddress
|
from common.types import SocketAddress
|
||||||
|
@ -25,38 +26,47 @@ logger = logging.getLogger(__name__)
|
||||||
# practical purposes. Revisit this if it's not.
|
# practical purposes. Revisit this if it's not.
|
||||||
NUM_FIND_SERVER_WORKERS = 32
|
NUM_FIND_SERVER_WORKERS = 32
|
||||||
|
|
||||||
IslandAPISearchResults = Dict[str, Optional[IIslandAPIClient]]
|
|
||||||
|
@dataclass
|
||||||
|
class IslandAPISearchResult:
|
||||||
|
server: SocketAddress
|
||||||
|
client: Optional[IIslandAPIClient]
|
||||||
|
|
||||||
|
|
||||||
|
IslandAPISearchResults = List[IslandAPISearchResult]
|
||||||
|
|
||||||
|
|
||||||
def find_available_island_apis(
|
def find_available_island_apis(
|
||||||
servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory
|
servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
|
||||||
) -> IslandAPISearchResults:
|
) -> IslandAPISearchResults:
|
||||||
server_list = list(servers)
|
server_list = list(servers)
|
||||||
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
server_iterator = ThreadSafeIterator(enumerate(server_list.__iter__()))
|
||||||
server_results: IslandAPISearchResults = {}
|
results: Dict[int, IslandAPISearchResult] = {}
|
||||||
|
|
||||||
run_worker_threads(
|
run_worker_threads(
|
||||||
_find_island_server,
|
_find_island_server,
|
||||||
"FindIslandServer",
|
"FindIslandServer",
|
||||||
args=(server_iterator, server_results, island_api_client_factory),
|
args=(server_iterator, results, island_api_client_factory),
|
||||||
num_workers=NUM_FIND_SERVER_WORKERS,
|
num_workers=NUM_FIND_SERVER_WORKERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return server_results
|
return [results[i] for i in sorted(results.keys())]
|
||||||
|
|
||||||
|
|
||||||
def _find_island_server(
|
def _find_island_server(
|
||||||
servers: Iterator[str],
|
servers: Iterator[Tuple[int, SocketAddress]],
|
||||||
server_results: IslandAPISearchResults,
|
server_results: Dict[int, IslandAPISearchResult],
|
||||||
island_api_client_factory: AbstractIslandAPIClientFactory,
|
island_api_client_factory: AbstractIslandAPIClientFactory,
|
||||||
):
|
):
|
||||||
with suppress(StopIteration):
|
with suppress(StopIteration):
|
||||||
server = next(servers)
|
index, server = next(servers)
|
||||||
server_results[server] = _check_if_island_server(server, island_api_client_factory)
|
server_results[index] = IslandAPISearchResult(
|
||||||
|
server, _check_if_island_server(server, island_api_client_factory)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_if_island_server(
|
def _check_if_island_server(
|
||||||
server: str, island_api_client_factory: AbstractIslandAPIClientFactory
|
server: SocketAddress, island_api_client_factory: AbstractIslandAPIClientFactory
|
||||||
) -> Optional[IIslandAPIClient]:
|
) -> Optional[IIslandAPIClient]:
|
||||||
logger.debug(f"Trying to connect to server: {server}")
|
logger.debug(f"Trying to connect to server: {server}")
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ AGENT_REGISTRATION = AgentRegistrationData(
|
||||||
machine_hardware_id=1,
|
machine_hardware_id=1,
|
||||||
start_time=0,
|
start_time=0,
|
||||||
parent_id=None,
|
parent_id=None,
|
||||||
cc_server=SERVER,
|
cc_server=str(SERVER),
|
||||||
network_interfaces=[],
|
network_interfaces=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,21 @@
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests_mock
|
import requests_mock
|
||||||
|
|
||||||
from common.agent_event_serializers import AgentEventSerializerRegistry
|
from common.agent_event_serializers import AgentEventSerializerRegistry
|
||||||
|
from common.types import SocketAddress
|
||||||
from infection_monkey.island_api_client import (
|
from infection_monkey.island_api_client import (
|
||||||
HTTPIslandAPIClientFactory,
|
HTTPIslandAPIClientFactory,
|
||||||
IIslandAPIClient,
|
IIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
)
|
)
|
||||||
from infection_monkey.network.relay.utils import find_available_island_apis
|
from infection_monkey.network.relay.utils import IslandAPISearchResult, find_available_island_apis
|
||||||
|
|
||||||
SERVER_1 = "1.1.1.1:12312"
|
SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
|
||||||
SERVER_2 = "2.2.2.2:4321"
|
SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
|
||||||
SERVER_3 = "3.3.3.3:3142"
|
SERVER_3 = SocketAddress(ip="3.3.3.3", port=3142)
|
||||||
SERVER_4 = "4.4.4.4:5000"
|
SERVER_4 = SocketAddress(ip="4.4.4.4", port=5000)
|
||||||
|
|
||||||
|
|
||||||
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
|
@ -45,11 +48,41 @@ def test_find_available_island_apis(
|
||||||
|
|
||||||
assert len(available_apis) == len(server_response_pairs)
|
assert len(available_apis) == len(server_response_pairs)
|
||||||
|
|
||||||
for server, island_api_client in available_apis.items():
|
for result in available_apis:
|
||||||
if server in expected_available_servers:
|
if result.server in expected_available_servers:
|
||||||
assert island_api_client is not None
|
assert result.client is not None
|
||||||
else:
|
else:
|
||||||
assert island_api_client is None
|
assert result.client is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_available_island_apis__preserves_input_order(island_api_client_factory):
|
||||||
|
available_servers = [SERVER_2, SERVER_3]
|
||||||
|
|
||||||
|
with requests_mock.Mocker() as mock:
|
||||||
|
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
|
||||||
|
for server in available_servers:
|
||||||
|
mock.get(f"https://{server}/api?action=is-up", text="")
|
||||||
|
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||||
|
|
||||||
|
for index in range(len(servers)):
|
||||||
|
assert available_apis[index].server == servers[index]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_none(value) -> bool:
|
||||||
|
return value is None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_island_client(value) -> bool:
|
||||||
|
return isinstance(value, IIslandAPIClient)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_server_and_predicate(
|
||||||
|
result: IslandAPISearchResult,
|
||||||
|
server: SocketAddress,
|
||||||
|
predicate: Callable[[Optional[IIslandAPIClient]], bool],
|
||||||
|
):
|
||||||
|
assert result.server == server
|
||||||
|
assert predicate(result.client)
|
||||||
|
|
||||||
|
|
||||||
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
|
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
|
||||||
|
@ -61,7 +94,7 @@ def test_find_available_island_apis__multiple_successes(island_api_client_factor
|
||||||
|
|
||||||
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||||
|
|
||||||
assert available_apis[SERVER_1] is None
|
_assert_server_and_predicate(available_apis[0], SERVER_1, _is_none)
|
||||||
assert available_apis[SERVER_4] is None
|
_assert_server_and_predicate(available_apis[1], SERVER_2, _is_island_client)
|
||||||
for server in available_servers:
|
_assert_server_and_predicate(available_apis[2], SERVER_3, _is_island_client)
|
||||||
assert isinstance(available_apis[server], IIslandAPIClient)
|
_assert_server_and_predicate(available_apis[3], SERVER_4, _is_none)
|
||||||
|
|
Loading…
Reference in New Issue