forked from p15670423/monkey
Agent: Modify find_server to return tuple of server and IIslandAPIClient
This commit is contained in:
parent
f4b47f8238
commit
bc19b5ea93
|
@ -5,7 +5,7 @@ import subprocess
|
|||
import sys
|
||||
from ipaddress import IPv4Address, IPv4Interface
|
||||
from pathlib import Path, WindowsPath
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
from pubsub.core import Publisher
|
||||
|
||||
|
@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
|
|||
from infection_monkey.exploit.wmiexec import WmiExploiter
|
||||
from infection_monkey.exploit.zerologon import ZerologonExploiter
|
||||
from infection_monkey.i_puppet import IPuppet, PluginType
|
||||
from infection_monkey.island_api_client import IIslandAPIClient
|
||||
from infection_monkey.master import AutomatedMaster
|
||||
from infection_monkey.master.control_channel import ControlChannel
|
||||
from infection_monkey.model import VictimHostFactory
|
||||
|
@ -110,7 +111,7 @@ class InfectionMonkey:
|
|||
self._opts = self._get_arguments(args)
|
||||
|
||||
# TODO: Revisit variable names
|
||||
server = self._get_server()
|
||||
server, island_api_client = self._get_server()
|
||||
# TODO: `address_to_port()` should return the port as an integer.
|
||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
||||
self._cmd_island_port = int(self._cmd_island_port)
|
||||
|
@ -136,9 +137,9 @@ class InfectionMonkey:
|
|||
|
||||
return opts
|
||||
|
||||
def _get_server(self):
|
||||
def _get_server(self) -> Tuple[str, IIslandAPIClient]:
|
||||
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
|
||||
server = find_server(self._opts.servers)
|
||||
server, island_api_client = find_server(self._opts.servers)
|
||||
if server:
|
||||
logger.info(f"Successfully connected to the island via {server}")
|
||||
else:
|
||||
|
@ -152,7 +153,7 @@ class InfectionMonkey:
|
|||
servers_to_close = (s for s in self._opts.servers if s != server)
|
||||
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
|
||||
|
||||
return server
|
||||
return server, island_api_client
|
||||
|
||||
@staticmethod
|
||||
def _log_arguments(args):
|
||||
|
|
|
@ -2,11 +2,12 @@ import logging
|
|||
import socket
|
||||
from contextlib import suppress
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
|
||||
from typing import Dict, Iterable, Iterator, MutableMapping, Optional, Tuple
|
||||
|
||||
from common.network.network_utils import address_to_ip_port
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClient,
|
||||
IIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
IslandAPIError,
|
||||
IslandAPITimeoutError,
|
||||
|
@ -25,10 +26,10 @@ logger = logging.getLogger(__name__)
|
|||
NUM_FIND_SERVER_WORKERS = 32
|
||||
|
||||
|
||||
def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||
def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
||||
server_list = list(servers)
|
||||
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
||||
server_results: Dict[str, bool] = {}
|
||||
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
|
||||
|
||||
run_worker_threads(
|
||||
_find_island_server,
|
||||
|
@ -39,24 +40,27 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
|
|||
|
||||
for server in server_list:
|
||||
if server_results[server]:
|
||||
return server
|
||||
island_api_client = server_results[server]
|
||||
return server, island_api_client
|
||||
|
||||
return None
|
||||
return (None, None)
|
||||
|
||||
|
||||
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
|
||||
def _find_island_server(
|
||||
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
|
||||
):
|
||||
with suppress(StopIteration):
|
||||
server = next(servers)
|
||||
server_status[server] = _check_if_island_server(server)
|
||||
|
||||
|
||||
def _check_if_island_server(server: str) -> bool:
|
||||
def _check_if_island_server(server: str) -> IIslandAPIClient:
|
||||
logger.debug(f"Trying to connect to server: {server}")
|
||||
|
||||
try:
|
||||
HTTPIslandAPIClient(server)
|
||||
island_api_client = HTTPIslandAPIClient(server)
|
||||
|
||||
return True
|
||||
return island_api_client
|
||||
except IslandAPIConnectionError as err:
|
||||
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||
except IslandAPITimeoutError as err:
|
||||
|
@ -66,7 +70,7 @@ def _check_if_island_server(server: str) -> bool:
|
|||
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||
)
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import requests_mock
|
||||
|
||||
from infection_monkey.island_api_client import IslandAPIConnectionError
|
||||
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
|
||||
from infection_monkey.network.relay.utils import find_server
|
||||
|
||||
SERVER_1 = "1.1.1.1:12312"
|
||||
|
@ -14,7 +14,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_server,server_response_pairs",
|
||||
"expected_server, server_response_pairs",
|
||||
[
|
||||
(None, [(server, {"exc": IslandAPIConnectionError}) for server in servers]),
|
||||
(
|
||||
|
@ -29,7 +29,9 @@ def test_find_server(expected_server, server_response_pairs):
|
|||
for server, response in server_response_pairs:
|
||||
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||
|
||||
assert find_server(servers) is expected_server
|
||||
actual_server, _ = find_server(servers)
|
||||
|
||||
assert actual_server is expected_server
|
||||
|
||||
|
||||
def test_find_server__multiple_successes():
|
||||
|
@ -39,4 +41,7 @@ def test_find_server__multiple_successes():
|
|||
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
|
||||
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
|
||||
|
||||
assert find_server(servers) == SERVER_2
|
||||
actual_server, actual_island_api_client = find_server(servers)
|
||||
|
||||
assert actual_server == SERVER_2
|
||||
assert isinstance(actual_island_api_client, IIslandAPIClient)
|
||||
|
|
Loading…
Reference in New Issue