Merge pull request #2351 from guardicore/2323-update-iislandapiclient-with-socketaddress

2323 update iislandapiclient with socketaddress
This commit is contained in:
Kekoa Kaaikala 2022-09-27 08:17:17 -04:00 committed by GitHub
commit 21cbf8d38b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 32 deletions

View File

@ -41,5 +41,8 @@ class SocketAddress(InfectionMonkeyBaseModel):
raise ValueError("SocketAddress requires a port") raise ValueError("SocketAddress requires a port")
return SocketAddress(ip=IPv4Address(ip), port=int(port)) return SocketAddress(ip=IPv4Address(ip), port=int(port))
def __hash__(self):
return hash(str(self))
def __str__(self): def __str__(self):
return f"{self.ip}:{self.port}" return f"{self.ip}:{self.port}"

View File

@ -16,6 +16,7 @@ from common.common_consts.timeouts import (
SHORT_REQUEST_TIMEOUT, SHORT_REQUEST_TIMEOUT,
) )
from common.credentials import Credentials from common.credentials import Credentials
from common.types import SocketAddress
from . import ( from . import (
AbstractIslandAPIClientFactory, AbstractIslandAPIClientFactory,
@ -79,7 +80,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
@handle_island_errors @handle_island_errors
def connect( def connect(
self, self,
island_server: str, island_server: SocketAddress,
): ):
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"https://{island_server}/api?action=is-up", f"https://{island_server}/api?action=is-up",
@ -88,8 +89,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
) )
response.raise_for_status() response.raise_for_status()
self._island_server = island_server self._api_url = f"https://{island_server}/api"
self._api_url = f"https://{self._island_server}/api"
@handle_island_errors @handle_island_errors
def send_log(self, log_contents: str): def send_log(self, log_contents: str):

View File

@ -5,6 +5,7 @@ from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials from common.credentials import Credentials
from common.types import SocketAddress
class IIslandAPIClient(ABC): class IIslandAPIClient(ABC):
@ -13,7 +14,7 @@ class IIslandAPIClient(ABC):
""" """
@abstractmethod @abstractmethod
def connect(self, island_server: str): def connect(self, island_server: SocketAddress):
""" """
Connect to the island's API Connect to the island's API

View File

@ -5,7 +5,7 @@ import subprocess
import sys import sys
from ipaddress import IPv4Interface from ipaddress import IPv4Interface
from pathlib import Path, WindowsPath from pathlib import Path, WindowsPath
from typing import List, Mapping, Optional, Tuple from typing import List, Optional, Tuple
from pubsub.core import Publisher from pubsub.core import Publisher
@ -55,6 +55,7 @@ from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.relay import TCPRelay from infection_monkey.network.relay import TCPRelay
from infection_monkey.network.relay.utils import ( from infection_monkey.network.relay.utils import (
IslandAPISearchResults,
find_available_island_apis, find_available_island_apis,
notify_disconnect, notify_disconnect,
send_remove_from_waitlist_control_message_to_relays, send_remove_from_waitlist_control_message_to_relays,
@ -114,19 +115,19 @@ class InfectionMonkey:
self._singleton = SystemSingleton() self._singleton = SystemSingleton()
self._opts = self._get_arguments(args) self._opts = self._get_arguments(args)
self._server_strings = [str(s) for s in self._opts.servers]
self._agent_event_serializer_registry = self._setup_agent_event_serializers() self._agent_event_serializer_registry = self._setup_agent_event_serializers()
server, self._island_api_client = self._connect_to_island_api() server, self._island_api_client = self._connect_to_island_api()
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_ip = server.ip
self._cmd_island_port = server.port
self._island_address = SocketAddress(self._cmd_island_ip, self._cmd_island_port) self._island_address = SocketAddress(self._cmd_island_ip, self._cmd_island_port)
self._control_client = ControlClient( self._control_client = ControlClient(
server_address=server, island_api_client=self._island_api_client server_address=str(server), island_api_client=self._island_api_client
) )
self._control_channel = ControlChannel(server, get_agent_id(), self._island_api_client) self._control_channel = ControlChannel(str(server), get_agent_id(), self._island_api_client)
self._register_agent(self._island_address) self._register_agent(self._island_address)
# TODO Refactor the telemetry messengers to accept control client # TODO Refactor the telemetry messengers to accept control client
@ -153,10 +154,10 @@ class InfectionMonkey:
return opts return opts
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: def _connect_to_island_api(self) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}") logger.debug(f"Trying to wake up with servers: {', '.join(map(str, self._opts.servers))}")
server_clients = find_available_island_apis( server_clients = find_available_island_apis(
self._server_strings, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
) )
server, island_api_client = self._select_server(server_clients) server, island_api_client = self._select_server(server_clients)
@ -165,13 +166,14 @@ class InfectionMonkey:
logger.info(f"Successfully connected to the island via {server}") logger.info(f"Successfully connected to the island via {server}")
else: else:
raise Exception( raise Exception(
f"Failed to connect to the island via any known servers: {self._server_strings}" "Failed to connect to the island via any known servers: "
f"[{', '.join(map(str, self._opts.servers))}]"
) )
# NOTE: Since we pass the address for each of our interfaces to the exploited # NOTE: Since we pass the address for each of our interfaces to the exploited
# machines, is it possible for a machine to unintentionally unregister itself from the # machines, is it possible for a machine to unintentionally unregister itself from the
# relay if it is able to connect to the relay over multiple interfaces? # relay if it is able to connect to the relay over multiple interfaces?
servers_to_close = (s for s in self._server_strings if s != server and server_clients[s]) servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s])
send_remove_from_waitlist_control_message_to_relays(servers_to_close) send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server, island_api_client return server, island_api_client
@ -189,10 +191,10 @@ class InfectionMonkey:
self._island_api_client.register_agent(agent_registration_data) self._island_api_client.register_agent(agent_registration_data)
def _select_server( def _select_server(
self, server_clients: Mapping[str, Optional[IIslandAPIClient]] self, server_clients: IslandAPISearchResults
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: ) -> Tuple[Optional[SocketAddress], Optional[IIslandAPIClient]]:
for server in self._server_strings: for server in self._opts.servers:
if server_clients[server]: if server_clients[server] is not None:
return server, server_clients[server] return server, server_clients[server]
return None, None return None, None
@ -260,8 +262,9 @@ class InfectionMonkey:
return agent_event_serializer_registry return agent_event_serializer_registry
def _build_server_list(self, relay_port: int): def _build_server_list(self, relay_port: int):
my_servers = map(str, self._opts.servers)
relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()] relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()]
return self._server_strings + relay_servers return my_servers + relay_servers
def _build_master(self, relay_port: int): def _build_master(self, relay_port: int):
servers = self._build_server_list(relay_port) servers = self._build_server_list(relay_port)
@ -276,7 +279,6 @@ class InfectionMonkey:
self._subscribe_events( self._subscribe_events(
event_queue, event_queue,
propagation_credentials_repository, propagation_credentials_repository,
self._control_client.server_address,
self._agent_event_serializer_registry, self._agent_event_serializer_registry,
) )
@ -303,7 +305,6 @@ class InfectionMonkey:
self, self,
event_queue: IAgentEventQueue, event_queue: IAgentEventQueue,
propagation_credentials_repository: IPropagationCredentialsRepository, propagation_credentials_repository: IPropagationCredentialsRepository,
server_address: str,
agent_event_serializer_registry: AgentEventSerializerRegistry, agent_event_serializer_registry: AgentEventSerializerRegistry,
): ):
event_queue.subscribe_type( event_queue.subscribe_type(

View File

@ -25,11 +25,12 @@ logger = logging.getLogger(__name__)
# practical purposes. Revisit this if it's not. # practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32 NUM_FIND_SERVER_WORKERS = 32
IslandAPISearchResults = Dict[str, Optional[IIslandAPIClient]]
IslandAPISearchResults = Dict[SocketAddress, Optional[IIslandAPIClient]]
def find_available_island_apis( def find_available_island_apis(
servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory servers: Iterable[SocketAddress], island_api_client_factory: AbstractIslandAPIClientFactory
) -> IslandAPISearchResults: ) -> IslandAPISearchResults:
server_list = list(servers) server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__()) server_iterator = ThreadSafeIterator(server_list.__iter__())
@ -46,7 +47,7 @@ def find_available_island_apis(
def _find_island_server( def _find_island_server(
servers: Iterator[str], servers: Iterator[SocketAddress],
server_results: IslandAPISearchResults, server_results: IslandAPISearchResults,
island_api_client_factory: AbstractIslandAPIClientFactory, island_api_client_factory: AbstractIslandAPIClientFactory,
): ):
@ -56,7 +57,7 @@ def _find_island_server(
def _check_if_island_server( def _check_if_island_server(
server: str, island_api_client_factory: AbstractIslandAPIClientFactory server: SocketAddress, island_api_client_factory: AbstractIslandAPIClientFactory
) -> Optional[IIslandAPIClient]: ) -> Optional[IIslandAPIClient]:
logger.debug(f"Trying to connect to server: {server}") logger.debug(f"Trying to connect to server: {server}")
@ -77,13 +78,12 @@ def _check_if_island_server(
return None return None
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[SocketAddress]):
for i, server in enumerate(servers, start=1): for i, server in enumerate(servers, start=1):
server_address = SocketAddress.from_string(server)
t = create_daemon_thread( t = create_daemon_thread(
target=notify_disconnect, target=notify_disconnect,
name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}", name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}",
args=(server_address,), args=(server,),
) )
t.start() t.start()

View File

@ -2,6 +2,7 @@ import pytest
import requests_mock import requests_mock
from common.agent_event_serializers import AgentEventSerializerRegistry from common.agent_event_serializers import AgentEventSerializerRegistry
from common.types import SocketAddress
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
HTTPIslandAPIClientFactory, HTTPIslandAPIClientFactory,
IIslandAPIClient, IIslandAPIClient,
@ -9,10 +10,10 @@ from infection_monkey.island_api_client import (
) )
from infection_monkey.network.relay.utils import find_available_island_apis from infection_monkey.network.relay.utils import find_available_island_apis
SERVER_1 = "1.1.1.1:12312" SERVER_1 = SocketAddress(ip="1.1.1.1", port=12312)
SERVER_2 = "2.2.2.2:4321" SERVER_2 = SocketAddress(ip="2.2.2.2", port=4321)
SERVER_3 = "3.3.3.3:3142" SERVER_3 = SocketAddress(ip="3.3.3.3", port=3142)
SERVER_4 = "4.4.4.4:5000" SERVER_4 = SocketAddress(ip="4.4.4.4", port=5000)
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]