Common: Make SocketAddress hashable

This commit is contained in:
Kekoa Kaaikala 2022-09-26 13:27:08 +00:00 committed by Shreya Malviya
parent c4804f06a9
commit 110542eeb8
4 changed files with 31 additions and 68 deletions

View File

@ -41,5 +41,8 @@ class SocketAddress(InfectionMonkeyBaseModel):
raise ValueError("SocketAddress requires a port") raise ValueError("SocketAddress requires a port")
return SocketAddress(ip=IPv4Address(ip), port=int(port)) return SocketAddress(ip=IPv4Address(ip), port=int(port))
def __hash__(self):
return hash(str(self))
def __str__(self): def __str__(self):
return f"{self.ip}:{self.port}" return f"{self.ip}:{self.port}"

View File

@ -115,7 +115,6 @@ class InfectionMonkey:
self._singleton = SystemSingleton() self._singleton = SystemSingleton()
self._opts = self._get_arguments(args) self._opts = self._get_arguments(args)
self._server_strings = [str(s) for s in self._opts.servers]
self._agent_event_serializer_registry = self._setup_agent_event_serializers() self._agent_event_serializer_registry = self._setup_agent_event_serializers()
@ -155,7 +154,7 @@ class InfectionMonkey:
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}") logger.debug(f"Trying to wake up with servers: {', '.join(map(str, self._opts.servers))}")
server_clients = find_available_island_apis( server_clients = find_available_island_apis(
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
) )
@ -166,13 +165,14 @@ class InfectionMonkey:
logger.info(f"Successfully connected to the island via {server}") logger.info(f"Successfully connected to the island via {server}")
else: else:
raise Exception( raise Exception(
f"Failed to connect to the island via any known servers: {self._server_strings}" "Failed to connect to the island via any known servers: "
f"[{', '.join(map(str, self._opts.servers))}]"
) )
# NOTE: Since we pass the address for each of our interfaces to the exploited # NOTE: Since we pass the address for each of our interfaces to the exploited
# machines, is it possible for a machine to unintentionally unregister itself from the # machines, is it possible for a machine to unintentionally unregister itself from the
# relay if it is able to connect to the relay over multiple interfaces? # relay if it is able to connect to the relay over multiple interfaces?
servers_to_close = (s for s in self._server_strings if s != server and server_clients[s]) servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s])
send_remove_from_waitlist_control_message_to_relays(servers_to_close) send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server, island_api_client return server, island_api_client
@ -192,9 +192,9 @@ class InfectionMonkey:
def _select_server( def _select_server(
self, server_clients: IslandAPISearchResults self, server_clients: IslandAPISearchResults
) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]: ) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
for result in server_clients: for server in self._opts.servers:
if result.client is not None: if server_clients[server] is not None:
return result.server, result.client return server, server_clients[server]
return None, None return None, None
@ -261,8 +261,9 @@ class InfectionMonkey:
return agent_event_serializer_registry return agent_event_serializer_registry
def _build_server_list(self, relay_port: int): def _build_server_list(self, relay_port: int):
my_servers = [str(s) for s in self._opts.servers]
relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()] relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()]
return self._server_strings + relay_servers return my_servers + relay_servers
def _build_master(self, relay_port: int): def _build_master(self, relay_port: int):
servers = self._build_server_list(relay_port) servers = self._build_server_list(relay_port)

View File

@ -1,8 +1,7 @@
import logging import logging
import socket import socket
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from typing import Dict, Iterable, Iterator, Optional
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.types import SocketAddress from common.types import SocketAddress
@ -27,42 +26,34 @@ logger = logging.getLogger(__name__)
NUM_FIND_SERVER_WORKERS = 32 NUM_FIND_SERVER_WORKERS = 32
@dataclass IslandAPISearchResults = Dict[SocketAddress, Optional[IIslandAPIClient]]
class IslandAPISearchResult:
server: SocketAddress
client: Optional[IIslandAPIClient]
IslandAPISearchResults = List[IslandAPISearchResult]
def find_available_island_apis( def find_available_island_apis(
servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
) -> IslandAPISearchResults: ) -> IslandAPISearchResults:
server_list = list(servers) server_list = list(servers)
server_iterator = ThreadSafeIterator(enumerate(server_list.__iter__())) server_iterator = ThreadSafeIterator(server_list.__iter__())
results: Dict[int, IslandAPISearchResult] = {} server_results: IslandAPISearchResults = {}
run_worker_threads( run_worker_threads(
_find_island_server, _find_island_server,
"FindIslandServer", "FindIslandServer",
args=(server_iterator, results, island_api_client_factory), args=(server_iterator, server_results, island_api_client_factory),
num_workers=NUM_FIND_SERVER_WORKERS, num_workers=NUM_FIND_SERVER_WORKERS,
) )
return [results[i] for i in sorted(results.keys())] return server_results
def _find_island_server( def _find_island_server(
servers: Iterator[Tuple[int, SocketAddress]], servers: Iterator[SocketAddress],
server_results: Dict[int, IslandAPISearchResult], server_results: IslandAPISearchResults,
island_api_client_factory: AbstractIslandAPIClientFactory, island_api_client_factory: AbstractIslandAPIClientFactory,
): ):
with suppress(StopIteration): with suppress(StopIteration):
index, server = next(servers) server = next(servers)
server_results[index] = IslandAPISearchResult( server_results[server] = _check_if_island_server(server, island_api_client_factory)
server, _check_if_island_server(server, island_api_client_factory)
)
def _check_if_island_server( def _check_if_island_server(

View File

@ -1,5 +1,3 @@
from typing import Callable, Optional
import pytest import pytest
import requests_mock import requests_mock
@ -10,7 +8,7 @@ from infection_monkey.island_api_client import (
IIslandAPIClient, IIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
) )
from infection_monkey.network.relay.utils import IslandAPISearchResult, find_available_island_apis from infection_monkey.network.relay.utils import find_available_island_apis
SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312) SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321) SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
@ -48,41 +46,11 @@ def test_find_available_island_apis(
assert len(available_apis) == len(server_response_pairs) assert len(available_apis) == len(server_response_pairs)
for result in available_apis: for server, island_api_client in available_apis.items():
if result.server in expected_available_servers: if server in expected_available_servers:
assert result.client is not None assert island_api_client is not None
else: else:
assert result.client is None assert island_api_client is None
def test_find_available_island_apis__preserves_input_order(island_api_client_factory):
available_servers = [SERVER_2, SERVER_3]
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
for server in available_servers:
mock.get(f"https://{server}/api?action=is-up", text="")
available_apis = find_available_island_apis(servers, island_api_client_factory)
for index in range(len(servers)):
assert available_apis[index].server == servers[index]
def _is_none(value) -> bool:
return value is None
def _is_island_client(value) -> bool:
return isinstance(value, IIslandAPIClient)
def _assert_server_and_predicate(
result: IslandAPISearchResult,
server: SocketAddress,
predicate: Callable[[Optional[IIslandAPIClient]], bool],
):
assert result.server == server
assert predicate(result.client)
def test_find_available_island_apis__multiple_successes(island_api_client_factory): def test_find_available_island_apis__multiple_successes(island_api_client_factory):
@ -94,7 +62,7 @@ def test_find_available_island_apis__multiple_successes(island_api_client_factor
available_apis = find_available_island_apis(servers, island_api_client_factory) available_apis = find_available_island_apis(servers, island_api_client_factory)
_assert_server_and_predicate(available_apis[0], SERVER_1, _is_none) assert available_apis[SERVER_1] is None
_assert_server_and_predicate(available_apis[1], SERVER_2, _is_island_client) assert available_apis[SERVER_4] is None
_assert_server_and_predicate(available_apis[2], SERVER_3, _is_island_client) for server in available_servers:
_assert_server_and_predicate(available_apis[3], SERVER_4, _is_none) assert isinstance(available_apis[server], IIslandAPIClient)