diff --git a/monkey/common/types.py b/monkey/common/types.py index 51353d293..5f86d5060 100644 --- a/monkey/common/types.py +++ b/monkey/common/types.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from ipaddress import IPv4Address from uuid import UUID @@ -5,6 +7,7 @@ from pydantic import PositiveInt, conint from typing_extensions import TypeAlias from common.base_models import InfectionMonkeyBaseModel +from common.network.network_utils import address_to_ip_port AgentID: TypeAlias = UUID HardwareID: TypeAlias = PositiveInt @@ -15,5 +18,19 @@ class SocketAddress(InfectionMonkeyBaseModel): ip: IPv4Address port: conint(ge=1, le=65535) # type: ignore[valid-type] + @classmethod + def from_string(cls, address_str: str) -> SocketAddress: + """ + Parse a SocketAddress object from a string + + :param address_str: A string of ip:port + :raises ValueError: If the string is not a valid ip:port + :return: SocketAddress with the IP and port + """ + ip, port = address_to_ip_port(address_str) + if port is None: + raise ValueError("SocketAddress requires a port") + return SocketAddress(ip=IPv4Address(ip), port=int(port)) + def __str__(self): return f"{self.ip}:{self.port}" diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index b3e2c1d4c..e86bdbee1 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -114,6 +114,7 @@ class InfectionMonkey: self._singleton = SystemSingleton() self._opts = self._get_arguments(args) + self._server_strings = [str(s) for s in self._opts.servers] self._agent_event_serializer_registry = self._setup_agent_event_serializers() @@ -144,7 +145,11 @@ class InfectionMonkey: def _get_arguments(args): arg_parser = argparse.ArgumentParser() arg_parser.add_argument("-p", "--parent") - arg_parser.add_argument("-s", "--servers", type=lambda arg: arg.strip().split(",")) + arg_parser.add_argument( + "-s", + "--servers", + type=lambda arg: [SocketAddress.from_string(s) for s in arg.strip().split(",")], + ) arg_parser.add_argument("-d", "--depth", type=positive_int, default=0) opts = arg_parser.parse_args(args) InfectionMonkey._log_arguments(opts) @@ -153,9 +158,9 @@ class InfectionMonkey: # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` def _connect_to_island_api(self) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: - logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") + logger.debug(f"Trying to wake up with servers: {', '.join(self._server_strings)}") server_clients = find_available_island_apis( - self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) + self._server_strings, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) ) server, island_api_client = self._select_server(server_clients) @@ -164,13 +169,13 @@ class InfectionMonkey: logger.info(f"Successfully connected to the island via {server}") else: raise Exception( - f"Failed to connect to the island via any known servers: {self._opts.servers}" + f"Failed to connect to the island via any known servers: {self._server_strings}" ) # NOTE: Since we pass the address for each of our interfaces to the exploited # machines, is it possible for a machine to unintentionally unregister itself from the # relay if it is able to connect to the relay over multiple interfaces? - servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s]) + servers_to_close = (s for s in self._server_strings if s != server and server_clients[s]) send_remove_from_waitlist_control_message_to_relays(servers_to_close) return server, island_api_client @@ -190,7 +195,7 @@ class InfectionMonkey: def _select_server( self, server_clients: Mapping[str, Optional[IIslandAPIClient]] ) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: - for server in self._opts.servers: + for server in self._server_strings: if server_clients[server]: return server, server_clients[server] @@ -198,7 +203,7 @@ class InfectionMonkey: @staticmethod def _log_arguments(args): - arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()]) + arg_string = ", ".join([f"{key}: {value}" for key, value in vars(args).items()]) logger.info(f"Monkey started with arguments: {arg_string}") def start(self): @@ -260,7 +265,7 @@ class InfectionMonkey: def _build_server_list(self, relay_port: int): relay_servers = [f"{ip}:{relay_port}" for ip in get_my_ip_addresses()] - return self._opts.servers + relay_servers + return self._server_strings + relay_servers def _build_master(self, relay_port: int): servers = self._build_server_list(relay_port) diff --git a/monkey/tests/unit_tests/common/test_socket_address.py b/monkey/tests/unit_tests/common/test_socket_address.py new file mode 100644 index 000000000..6c3b2d448 --- /dev/null +++ b/monkey/tests/unit_tests/common/test_socket_address.py @@ -0,0 +1,34 @@ +import pytest + +from common.types import SocketAddress + +GOOD_IP = "192.168.1.1" +BAD_IP = "192.168.1.999" +GOOD_PORT = 1234 +BAD_PORT = 99999 + + +def test_socket_address__from_string(): + expected = SocketAddress(ip=GOOD_IP, port=GOOD_PORT) + + address = SocketAddress.from_string(f"{GOOD_IP}:{GOOD_PORT}") + + assert address == expected + + +@pytest.mark.parametrize( + "bad_address", + [ + "not an address", + ":", + GOOD_IP, + str(GOOD_PORT), + f"{GOOD_IP}:", + f":{GOOD_PORT}", + f"{BAD_IP}:{GOOD_PORT}", + f"{GOOD_IP}:{BAD_PORT}", + ], +) +def test_socket_address__from_string_raises(bad_address: str): + with pytest.raises(ValueError): + SocketAddress.from_string(bad_address)