forked from p15670423/monkey
Common: Make SocketAddress hashable
This commit is contained in:
parent
c4804f06a9
commit
110542eeb8
|
@ -41,5 +41,8 @@ class SocketAddress(InfectionMonkeyBaseModel):
|
||||||
raise ValueError("SocketAddress requires a port")
|
raise ValueError("SocketAddress requires a port")
|
||||||
return SocketAddress(ip=IPv4Address(ip), port=int(port))
|
return SocketAddress(ip=IPv4Address(ip), port=int(port))
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(str(self))
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"{self.ip}:{self.port}"
|
return f"{self.ip}:{self.port}"
|
||||||
|
|
|
@ -115,7 +115,6 @@ class InfectionMonkey:
|
||||||
|
|
||||||
self._singleton = SystemSingleton()
|
self._singleton = SystemSingleton()
|
||||||
self._opts = self._get_arguments(args)
|
self._opts = self._get_arguments(args)
|
||||||
self._server_strings = [str(s) for s in self._opts.servers]
|
|
||||||
|
|
||||||
self._agent_event_serializer_registry = self._setup_agent_event_serializers()
|
self._agent_event_serializer_registry = self._setup_agent_event_serializers()
|
||||||
|
|
||||||
|
@ -155,7 +154,7 @@ class InfectionMonkey:
|
||||||
|
|
||||||
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
|
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
|
||||||
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
||||||
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}")
|
logger.debug(f"Trying to wake up with servers: {', '.join(map(str, self._opts.servers))}")
|
||||||
server_clients = find_available_island_apis(
|
server_clients = find_available_island_apis(
|
||||||
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
|
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
|
||||||
)
|
)
|
||||||
|
@ -166,13 +165,14 @@ class InfectionMonkey:
|
||||||
logger.info(f"Successfully connected to the island via {server}")
|
logger.info(f"Successfully connected to the island via {server}")
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to connect to the island via any known servers: {self._server_strings}"
|
"Failed to connect to the island via any known servers: "
|
||||||
|
f"[{', '.join(map(str, self._opts.servers))}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: Since we pass the address for each of our interfaces to the exploited
|
# NOTE: Since we pass the address for each of our interfaces to the exploited
|
||||||
# machines, is it possible for a machine to unintentionally unregister itself from the
|
# machines, is it possible for a machine to unintentionally unregister itself from the
|
||||||
# relay if it is able to connect to the relay over multiple interfaces?
|
# relay if it is able to connect to the relay over multiple interfaces?
|
||||||
servers_to_close = (s for s in self._server_strings if s != server and server_clients[s])
|
servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s])
|
||||||
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
|
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
|
||||||
|
|
||||||
return server, island_api_client
|
return server, island_api_client
|
||||||
|
@ -192,9 +192,9 @@ class InfectionMonkey:
|
||||||
def _select_server(
|
def _select_server(
|
||||||
self, server_clients: IslandAPISearchResults
|
self, server_clients: IslandAPISearchResults
|
||||||
) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
|
) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
|
||||||
for result in server_clients:
|
for server in self._opts.servers:
|
||||||
if result.client is not None:
|
if server_clients[server] is not None:
|
||||||
return result.server, result.client
|
return server, server_clients[server]
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
@ -261,8 +261,9 @@ class InfectionMonkey:
|
||||||
return agent_event_serializer_registry
|
return agent_event_serializer_registry
|
||||||
|
|
||||||
def _build_server_list(self, relay_port: int):
|
def _build_server_list(self, relay_port: int):
|
||||||
|
my_servers = [str(s) for s in self._opts.servers]
|
||||||
relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()]
|
relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()]
|
||||||
return self._server_strings + relay_servers
|
return my_servers + relay_servers
|
||||||
|
|
||||||
def _build_master(self, relay_port: int):
|
def _build_master(self, relay_port: int):
|
||||||
servers = self._build_server_list(relay_port)
|
servers = self._build_server_list(relay_port)
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from typing import Dict, Iterable, Iterator, Optional
|
||||||
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
|
|
||||||
|
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
||||||
from common.types import SocketAddress
|
from common.types import SocketAddress
|
||||||
|
@ -27,42 +26,34 @@ logger = logging.getLogger(__name__)
|
||||||
NUM_FIND_SERVER_WORKERS = 32
|
NUM_FIND_SERVER_WORKERS = 32
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
IslandAPISearchResults = Dict[SocketAddress, Optional[IIslandAPIClient]]
|
||||||
class IslandAPISearchResult:
|
|
||||||
server: SocketAddress
|
|
||||||
client: Optional[IIslandAPIClient]
|
|
||||||
|
|
||||||
|
|
||||||
IslandAPISearchResults = List[IslandAPISearchResult]
|
|
||||||
|
|
||||||
|
|
||||||
def find_available_island_apis(
|
def find_available_island_apis(
|
||||||
servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
|
servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
|
||||||
) -> IslandAPISearchResults:
|
) -> IslandAPISearchResults:
|
||||||
server_list = list(servers)
|
server_list = list(servers)
|
||||||
server_iterator = ThreadSafeIterator(enumerate(server_list.__iter__()))
|
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
||||||
results: Dict[int, IslandAPISearchResult] = {}
|
server_results: IslandAPISearchResults = {}
|
||||||
|
|
||||||
run_worker_threads(
|
run_worker_threads(
|
||||||
_find_island_server,
|
_find_island_server,
|
||||||
"FindIslandServer",
|
"FindIslandServer",
|
||||||
args=(server_iterator, results, island_api_client_factory),
|
args=(server_iterator, server_results, island_api_client_factory),
|
||||||
num_workers=NUM_FIND_SERVER_WORKERS,
|
num_workers=NUM_FIND_SERVER_WORKERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [results[i] for i in sorted(results.keys())]
|
return server_results
|
||||||
|
|
||||||
|
|
||||||
def _find_island_server(
|
def _find_island_server(
|
||||||
servers: Iterator[Tuple[int, SocketAddress]],
|
servers: Iterator[SocketAddress],
|
||||||
server_results: Dict[int, IslandAPISearchResult],
|
server_results: IslandAPISearchResults,
|
||||||
island_api_client_factory: AbstractIslandAPIClientFactory,
|
island_api_client_factory: AbstractIslandAPIClientFactory,
|
||||||
):
|
):
|
||||||
with suppress(StopIteration):
|
with suppress(StopIteration):
|
||||||
index, server = next(servers)
|
server = next(servers)
|
||||||
server_results[index] = IslandAPISearchResult(
|
server_results[server] = _check_if_island_server(server, island_api_client_factory)
|
||||||
server, _check_if_island_server(server, island_api_client_factory)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_if_island_server(
|
def _check_if_island_server(
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests_mock
|
import requests_mock
|
||||||
|
|
||||||
|
@ -10,7 +8,7 @@ from infection_monkey.island_api_client import (
|
||||||
IIslandAPIClient,
|
IIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
)
|
)
|
||||||
from infection_monkey.network.relay.utils import IslandAPISearchResult, find_available_island_apis
|
from infection_monkey.network.relay.utils import find_available_island_apis
|
||||||
|
|
||||||
SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
|
SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
|
||||||
SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
|
SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
|
||||||
|
@ -48,41 +46,11 @@ def test_find_available_island_apis(
|
||||||
|
|
||||||
assert len(available_apis) == len(server_response_pairs)
|
assert len(available_apis) == len(server_response_pairs)
|
||||||
|
|
||||||
for result in available_apis:
|
for server, island_api_client in available_apis.items():
|
||||||
if result.server in expected_available_servers:
|
if server in expected_available_servers:
|
||||||
assert result.client is not None
|
assert island_api_client is not None
|
||||||
else:
|
else:
|
||||||
assert result.client is None
|
assert island_api_client is None
|
||||||
|
|
||||||
|
|
||||||
def test_find_available_island_apis__preserves_input_order(island_api_client_factory):
|
|
||||||
available_servers = [SERVER_2, SERVER_3]
|
|
||||||
|
|
||||||
with requests_mock.Mocker() as mock:
|
|
||||||
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
|
|
||||||
for server in available_servers:
|
|
||||||
mock.get(f"https://{server}/api?action=is-up", text="")
|
|
||||||
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
|
||||||
|
|
||||||
for index in range(len(servers)):
|
|
||||||
assert available_apis[index].server == servers[index]
|
|
||||||
|
|
||||||
|
|
||||||
def _is_none(value) -> bool:
|
|
||||||
return value is None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_island_client(value) -> bool:
|
|
||||||
return isinstance(value, IIslandAPIClient)
|
|
||||||
|
|
||||||
|
|
||||||
def _assert_server_and_predicate(
|
|
||||||
result: IslandAPISearchResult,
|
|
||||||
server: SocketAddress,
|
|
||||||
predicate: Callable[[Optional[IIslandAPIClient]], bool],
|
|
||||||
):
|
|
||||||
assert result.server == server
|
|
||||||
assert predicate(result.client)
|
|
||||||
|
|
||||||
|
|
||||||
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
|
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
|
||||||
|
@ -94,7 +62,7 @@ def test_find_available_island_apis__multiple_successes(island_api_client_factor
|
||||||
|
|
||||||
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||||
|
|
||||||
_assert_server_and_predicate(available_apis[0], SERVER_1, _is_none)
|
assert available_apis[SERVER_1] is None
|
||||||
_assert_server_and_predicate(available_apis[1], SERVER_2, _is_island_client)
|
assert available_apis[SERVER_4] is None
|
||||||
_assert_server_and_predicate(available_apis[2], SERVER_3, _is_island_client)
|
for server in available_servers:
|
||||||
_assert_server_and_predicate(available_apis[3], SERVER_4, _is_none)
|
assert isinstance(available_apis[server], IIslandAPIClient)
|
||||||
|
|
Loading…
Reference in New Issue