forked from p15670423/monkey
Agent: Add IIslandAPIClient.connect()
Different clients may have different dependencies in their constructors. Use connect() instead of __init__() to connect to the Island. Add an AbstractIslandAPIClientFactory and HTTPIslandAPIClientFactory to facilitate this.
This commit is contained in:
parent
e9433ad23b
commit
9807c23571
|
@ -7,4 +7,4 @@ from .island_api_client_errors import (
|
|||
)
|
||||
from .i_island_api_client import IIslandAPIClient
|
||||
from .abstract_island_api_client_factory import AbstractIslandAPIClientFactory
|
||||
from .http_island_api_client import HTTPIslandAPIClient
|
||||
from .http_island_api_client import HTTPIslandAPIClient, HTTPIslandAPIClientFactory
|
||||
|
|
|
@ -4,11 +4,12 @@ from typing import List, Sequence
|
|||
|
||||
import requests
|
||||
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||
|
||||
from . import (
|
||||
AbstractIslandAPIClientFactory,
|
||||
IIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
IslandAPIError,
|
||||
|
@ -49,11 +50,16 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
A client for the Island's HTTP API
|
||||
"""
|
||||
|
||||
@handle_island_errors
|
||||
def __init__(
|
||||
self,
|
||||
agent_event_serializer_registry: AgentEventSerializerRegistry,
|
||||
):
|
||||
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||
|
||||
@handle_island_errors
|
||||
def connect(
|
||||
self,
|
||||
island_server: str,
|
||||
agent_event_serializer_registry: AgentEventSerializerRegistry = None,
|
||||
):
|
||||
response = requests.get( # noqa: DUO123
|
||||
f"https://{island_server}/api?action=is-up",
|
||||
|
@ -65,8 +71,6 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
self._island_server = island_server
|
||||
self._api_url = f"https://{self._island_server}/api"
|
||||
|
||||
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||
|
||||
@handle_island_errors
|
||||
def send_log(self, log_contents: str):
|
||||
response = requests.post( # noqa: DUO123
|
||||
|
@ -110,3 +114,14 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
raise IslandAPIRequestError(err)
|
||||
|
||||
return serialized_events
|
||||
|
||||
|
||||
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
|
||||
def __init__(
|
||||
self,
|
||||
agent_event_serializer_registry: AgentEventSerializerRegistry = None,
|
||||
):
|
||||
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||
|
||||
def create_island_api_client(self):
|
||||
return HTTPIslandAPIClient(self._agent_event_serializer_registry)
|
||||
|
|
|
@ -10,9 +10,9 @@ class IIslandAPIClient(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, island_server: str):
|
||||
def connect(self, island_server: str):
|
||||
"""
|
||||
Construct an island API client and connect it to the island
|
||||
Connectto the island's API
|
||||
|
||||
:param island_server: The socket address of the API
|
||||
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
||||
|
|
|
@ -45,7 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
|
|||
from infection_monkey.exploit.wmiexec import WmiExploiter
|
||||
from infection_monkey.exploit.zerologon import ZerologonExploiter
|
||||
from infection_monkey.i_puppet import IPuppet, PluginType
|
||||
from infection_monkey.island_api_client import IIslandAPIClient
|
||||
from infection_monkey.island_api_client import HTTPIslandAPIClientFactory, IIslandAPIClient
|
||||
from infection_monkey.master import AutomatedMaster
|
||||
from infection_monkey.master.control_channel import ControlChannel
|
||||
from infection_monkey.model import VictimHostFactory
|
||||
|
@ -114,12 +114,12 @@ class InfectionMonkey:
|
|||
self._agent_event_serializer_registry = self._setup_agent_event_serializers()
|
||||
|
||||
# TODO: Revisit variable names
|
||||
server, island_api_client = self._connect_to_island_api()
|
||||
server, self._island_api_client = self._connect_to_island_api()
|
||||
# TODO: `address_to_port()` should return the port as an integer.
|
||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
||||
self._cmd_island_port = int(self._cmd_island_port)
|
||||
self._control_client = ControlClient(
|
||||
server_address=server, island_api_client=island_api_client
|
||||
server_address=server, island_api_client=self._island_api_client
|
||||
)
|
||||
|
||||
# TODO Refactor the telemetry messengers to accept control client
|
||||
|
@ -145,7 +145,9 @@ class InfectionMonkey:
|
|||
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
|
||||
def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]:
|
||||
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
|
||||
server_clients = find_available_island_apis(self._opts.servers)
|
||||
server_clients = find_available_island_apis(
|
||||
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
|
||||
)
|
||||
|
||||
server, island_api_client = self._select_server(server_clients)
|
||||
|
||||
|
@ -289,7 +291,7 @@ class InfectionMonkey:
|
|||
),
|
||||
)
|
||||
event_queue.subscribe_all_events(
|
||||
AgentEventForwarder(self.island_api_client, agent_event_serializer_registry).send_event
|
||||
AgentEventForwarder(self._island_api_client, agent_event_serializer_registry).send_event
|
||||
)
|
||||
|
||||
def _build_puppet(
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional,
|
|||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
||||
from common.network.network_utils import address_to_ip_port
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClient,
|
||||
AbstractIslandAPIClientFactory,
|
||||
IIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
IslandAPIError,
|
||||
|
@ -27,7 +27,9 @@ logger = logging.getLogger(__name__)
|
|||
NUM_FIND_SERVER_WORKERS = 32
|
||||
|
||||
|
||||
def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]:
|
||||
def find_available_island_apis(
|
||||
servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory
|
||||
) -> Mapping[str, Optional[IIslandAPIClient]]:
|
||||
server_list = list(servers)
|
||||
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
||||
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
|
||||
|
@ -35,7 +37,7 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[
|
|||
run_worker_threads(
|
||||
_find_island_server,
|
||||
"FindIslandServer",
|
||||
args=(server_iterator, server_results),
|
||||
args=(server_iterator, server_results, island_api_client_factory),
|
||||
num_workers=NUM_FIND_SERVER_WORKERS,
|
||||
)
|
||||
|
||||
|
@ -43,18 +45,25 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[
|
|||
|
||||
|
||||
def _find_island_server(
|
||||
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
|
||||
servers: Iterator[str],
|
||||
server_status: MutableMapping[str, Optional[IIslandAPIClient]],
|
||||
island_api_client_factory: AbstractIslandAPIClientFactory,
|
||||
):
|
||||
with suppress(StopIteration):
|
||||
server = next(servers)
|
||||
server_status[server] = _check_if_island_server(server)
|
||||
server_status[server] = _check_if_island_server(server, island_api_client_factory)
|
||||
|
||||
|
||||
def _check_if_island_server(server: str) -> IIslandAPIClient:
|
||||
def _check_if_island_server(
|
||||
server: str, island_api_client_factory: AbstractIslandAPIClientFactory
|
||||
) -> IIslandAPIClient:
|
||||
logger.debug(f"Trying to connect to server: {server}")
|
||||
|
||||
try:
|
||||
return HTTPIslandAPIClient(server)
|
||||
client = island_api_client_factory.create_island_api_client()
|
||||
client.connect(server)
|
||||
|
||||
return client
|
||||
except IslandAPIConnectionError as err:
|
||||
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||
except IslandAPITimeoutError as err:
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import pytest
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import requests_mock
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
|
||||
from common.agent_event_serializers import (
|
||||
AgentEventSerializerRegistry,
|
||||
PydanticAgentEventSerializer,
|
||||
)
|
||||
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
|
@ -28,6 +29,32 @@ ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
|
|||
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||
|
||||
|
||||
class Event1(AbstractAgentEvent):
|
||||
a: int
|
||||
|
||||
|
||||
class Event2(AbstractAgentEvent):
|
||||
b: str
|
||||
|
||||
|
||||
class Event3(AbstractAgentEvent):
|
||||
c: int
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_event_serializer_registry():
|
||||
agent_event_serializer_registry = AgentEventSerializerRegistry()
|
||||
agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1)
|
||||
agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2)
|
||||
|
||||
return agent_event_serializer_registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def island_api_client(agent_event_serializer_registry):
|
||||
return HTTPIslandAPIClient(agent_event_serializer_registry)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
[
|
||||
|
@ -36,12 +63,12 @@ AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
|||
(Exception, IslandAPIError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client(actual_error, expected_error):
|
||||
def test_island_api_client(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI, exc=actual_error)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
HTTPIslandAPIClient(SERVER)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -51,12 +78,12 @@ def test_island_api_client(actual_error, expected_error):
|
|||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__status_code(status_code, expected_error):
|
||||
def test_island_api_client__status_code(island_api_client, status_code, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI, status_code=status_code)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
HTTPIslandAPIClient(SERVER)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -67,10 +94,10 @@ def test_island_api_client__status_code(status_code, expected_error):
|
|||
(Exception, IslandAPIError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__send_log(actual_error, expected_error):
|
||||
def test_island_api_client__send_log(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_LOG_URI, exc=actual_error)
|
||||
|
@ -84,10 +111,10 @@ def test_island_api_client__send_log(actual_error, expected_error):
|
|||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_send_log__status_code(status_code, expected_error):
|
||||
def test_island_api_client_send_log__status_code(island_api_client, status_code, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_LOG_URI, status_code=status_code)
|
||||
|
@ -102,10 +129,10 @@ def test_island_api_client_send_log__status_code(status_code, expected_error):
|
|||
(Exception, IslandAPIError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__get_pba_file(actual_error, expected_error):
|
||||
def test_island_api_client__get_pba_file(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error)
|
||||
|
@ -119,37 +146,19 @@ def test_island_api_client__get_pba_file(actual_error, expected_error):
|
|||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_get_pba_file__status_code(status_code, expected_error):
|
||||
def test_island_api_client_get_pba_file__status_code(
|
||||
island_api_client, status_code, expected_error
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code)
|
||||
island_api_client.get_pba_file(filename=PBA_FILE)
|
||||
|
||||
|
||||
class Event1(AbstractAgentEvent):
|
||||
a: int
|
||||
|
||||
|
||||
class Event2(AbstractAgentEvent):
|
||||
b: str
|
||||
|
||||
|
||||
class Event3(AbstractAgentEvent):
|
||||
c: int
|
||||
|
||||
@pytest.fixture
|
||||
def agent_event_serializer_registry():
|
||||
agent_event_serializer_registry = AgentEventSerializerRegistry()
|
||||
agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1)
|
||||
agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2)
|
||||
|
||||
return agent_event_serializer_registry
|
||||
|
||||
|
||||
def test_island_api_client_send_events__serialization(agent_event_serializer_registry):
|
||||
def test_island_api_client_send_events__serialization(island_api_client):
|
||||
events_to_send = [
|
||||
Event1(source=AGENT_ID, timestamp=0, a=1),
|
||||
Event2(source=AGENT_ID, timestamp=0, b="hello"),
|
||||
|
@ -176,16 +185,17 @@ def test_island_api_client_send_events__serialization(agent_event_serializer_reg
|
|||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
m.post(ISLAND_SEND_EVENTS_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
island_api_client.send_events(events=events_to_send)
|
||||
|
||||
assert m.last_request.json() == expected_json
|
||||
|
||||
def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry):
|
||||
|
||||
def test_island_api_client_send_events__serialization_failed(island_api_client):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(IslandAPIRequestError):
|
||||
m.post(ISLAND_SEND_EVENTS_URI)
|
||||
|
@ -200,12 +210,10 @@ def test_island_api_client_send_events__serialization_failed(agent_event_seriali
|
|||
(Exception, IslandAPIError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__send_events(
|
||||
actual_error, expected_error, agent_event_serializer_registry
|
||||
):
|
||||
def test_island_api_client__send_events(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error)
|
||||
|
@ -219,12 +227,10 @@ def test_island_api_client__send_events(
|
|||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_send_events__status_code(
|
||||
status_code, expected_error, agent_event_serializer_registry
|
||||
):
|
||||
def test_island_api_client_send_events__status_code(island_api_client, status_code, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
import pytest
|
||||
import requests_mock
|
||||
|
||||
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClientFactory,
|
||||
IIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
)
|
||||
from infection_monkey.network.relay.utils import find_available_island_apis
|
||||
|
||||
SERVER_1 = "1.1.1.1:12312"
|
||||
|
@ -13,6 +18,11 @@ SERVER_4 = "4.4.4.4:5000"
|
|||
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def island_api_client_factory():
|
||||
return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_available_servers, server_response_pairs",
|
||||
[
|
||||
|
@ -24,12 +34,14 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_find_available_island_apis(expected_available_servers, server_response_pairs):
|
||||
def test_find_available_island_apis(
|
||||
expected_available_servers, server_response_pairs, island_api_client_factory
|
||||
):
|
||||
with requests_mock.Mocker() as mock:
|
||||
for server, response in server_response_pairs:
|
||||
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||
|
||||
available_apis = find_available_island_apis(servers)
|
||||
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||
|
||||
assert len(available_apis) == len(server_response_pairs)
|
||||
|
||||
|
@ -40,14 +52,14 @@ def test_find_available_island_apis(expected_available_servers, server_response_
|
|||
assert island_api_client is None
|
||||
|
||||
|
||||
def test_find_available_island_apis__multiple_successes():
|
||||
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
|
||||
available_servers = [SERVER_2, SERVER_3]
|
||||
with requests_mock.Mocker() as mock:
|
||||
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
|
||||
for server in available_servers:
|
||||
mock.get(f"https://{server}/api?action=is-up", text="")
|
||||
|
||||
available_apis = find_available_island_apis(servers)
|
||||
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||
|
||||
assert available_apis[SERVER_1] is None
|
||||
assert available_apis[SERVER_4] is None
|
||||
|
|
Loading…
Reference in New Issue