commit
741d192eab
|
@ -3,18 +3,15 @@ import queue
|
||||||
import threading
|
import threading
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
import requests
|
from common.agent_event_serializers import AgentEventSerializerRegistry
|
||||||
|
|
||||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
|
from infection_monkey.island_api_client import IIslandAPIClient
|
||||||
from infection_monkey.utils.threading import create_daemon_thread
|
from infection_monkey.utils.threading import create_daemon_thread
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TIME_PERIOD_SECONDS = 5
|
DEFAULT_TIME_PERIOD_SECONDS = 5
|
||||||
AGENT_EVENTS_API_URL = "https://%s/api/agent-events"
|
|
||||||
|
|
||||||
|
|
||||||
class AgentEventForwarder:
|
class AgentEventForwarder:
|
||||||
|
@ -23,27 +20,21 @@ class AgentEventForwarder:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, server_address: str, agent_event_serializer_registry: AgentEventSerializerRegistry
|
self,
|
||||||
|
island_api_client: IIslandAPIClient,
|
||||||
|
agent_event_serializer_registry: AgentEventSerializerRegistry,
|
||||||
):
|
):
|
||||||
self._server_address = server_address
|
|
||||||
self._agent_event_serializer_registry = agent_event_serializer_registry
|
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||||
|
|
||||||
self._batching_agent_event_forwarder = BatchingAgentEventForwarder(self._server_address)
|
self._batching_agent_event_forwarder = BatchingAgentEventForwarder(island_api_client)
|
||||||
self._batching_agent_event_forwarder.start()
|
self._batching_agent_event_forwarder.start()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self._batching_agent_event_forwarder.stop()
|
self._batching_agent_event_forwarder.stop()
|
||||||
|
|
||||||
def send_event(self, event: AbstractAgentEvent):
|
def send_event(self, event: AbstractAgentEvent):
|
||||||
serialized_event = self._serialize_event(event)
|
self._batching_agent_event_forwarder.add_event_to_queue(event)
|
||||||
self._batching_agent_event_forwarder.add_event_to_queue(serialized_event)
|
logger.debug(f"Sending event of type {type(event).__name__} to the Island")
|
||||||
logger.debug(
|
|
||||||
f"Sending event of type {type(event).__name__} to the Island at {self._server_address}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _serialize_event(self, event: AbstractAgentEvent):
|
|
||||||
serializer = self._agent_event_serializer_registry[event.__class__]
|
|
||||||
return serializer.serialize(event)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchingAgentEventForwarder:
|
class BatchingAgentEventForwarder:
|
||||||
|
@ -51,8 +42,10 @@ class BatchingAgentEventForwarder:
|
||||||
Handles the batching and sending of the Agent's events to the Island
|
Handles the batching and sending of the Agent's events to the Island
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, server_address: str, time_period: int = DEFAULT_TIME_PERIOD_SECONDS):
|
def __init__(
|
||||||
self._server_address = server_address
|
self, island_api_client: IIslandAPIClient, time_period: int = DEFAULT_TIME_PERIOD_SECONDS
|
||||||
|
):
|
||||||
|
self._island_api_client = island_api_client
|
||||||
self._time_period = time_period
|
self._time_period = time_period
|
||||||
|
|
||||||
self._queue: queue.Queue[AbstractAgentEvent] = queue.Queue()
|
self._queue: queue.Queue[AbstractAgentEvent] = queue.Queue()
|
||||||
|
@ -64,7 +57,7 @@ class BatchingAgentEventForwarder:
|
||||||
)
|
)
|
||||||
self._batch_and_send_thread.start()
|
self._batch_and_send_thread.start()
|
||||||
|
|
||||||
def add_event_to_queue(self, serialized_event: JSONSerializable):
|
def add_event_to_queue(self, serialized_event: AbstractAgentEvent):
|
||||||
self._queue.put(serialized_event)
|
self._queue.put(serialized_event)
|
||||||
|
|
||||||
def _manage_event_batches(self):
|
def _manage_event_batches(self):
|
||||||
|
@ -84,18 +77,10 @@ class BatchingAgentEventForwarder:
|
||||||
events.append(self._queue.get(block=False))
|
events.append(self._queue.get(block=False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Sending Agent events to Island at {self._server_address}: {events}")
|
logger.debug(f"Sending Agent events to Island: {events}")
|
||||||
requests.post( # noqa: DUO123
|
self._island_api_client.send_events(events)
|
||||||
AGENT_EVENTS_API_URL % (self._server_address,),
|
except Exception as err:
|
||||||
json=events,
|
logger.warning(f"Exception caught when connecting to the Island: {err}")
|
||||||
verify=False,
|
|
||||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
f"Exception caught when connecting to the Island at {self._server_address}"
|
|
||||||
f": {exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _send_remaining_events(self):
|
def _send_remaining_events(self):
|
||||||
self._send_events_to_island()
|
self._send_events_to_island()
|
||||||
|
|
|
@ -6,4 +6,5 @@ from .island_api_client_errors import (
|
||||||
IslandAPITimeoutError,
|
IslandAPITimeoutError,
|
||||||
)
|
)
|
||||||
from .i_island_api_client import IIslandAPIClient
|
from .i_island_api_client import IIslandAPIClient
|
||||||
from .http_island_api_client import HTTPIslandAPIClient
|
from .abstract_island_api_client_factory import AbstractIslandAPIClientFactory
|
||||||
|
from .http_island_api_client import HTTPIslandAPIClient, HTTPIslandAPIClientFactory
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from . import IIslandAPIClient
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractIslandAPIClientFactory(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def create_island_api_client(self) -> IIslandAPIClient:
|
||||||
|
"""
|
||||||
|
Create an IIslandAPIClient
|
||||||
|
|
||||||
|
:return: A concrete instance of an IIslandAPIClient
|
||||||
|
"""
|
|
@ -1,12 +1,16 @@
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Sequence
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
|
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||||
|
from common.agent_events import AbstractAgentEvent
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
|
AbstractIslandAPIClientFactory,
|
||||||
IIslandAPIClient,
|
IIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
IslandAPIError,
|
IslandAPIError,
|
||||||
|
@ -34,6 +38,8 @@ def handle_island_errors(fn):
|
||||||
raise IslandAPIError(err)
|
raise IslandAPIError(err)
|
||||||
except TimeoutError as err:
|
except TimeoutError as err:
|
||||||
raise IslandAPITimeoutError(err)
|
raise IslandAPITimeoutError(err)
|
||||||
|
except IslandAPIError as err:
|
||||||
|
raise err
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise IslandAPIError(err)
|
raise IslandAPIError(err)
|
||||||
|
|
||||||
|
@ -45,8 +51,17 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
A client for the Island's HTTP API
|
A client for the Island's HTTP API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_event_serializer_registry: AgentEventSerializerRegistry,
|
||||||
|
):
|
||||||
|
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||||
|
|
||||||
@handle_island_errors
|
@handle_island_errors
|
||||||
def __init__(self, island_server: str):
|
def connect(
|
||||||
|
self,
|
||||||
|
island_server: str,
|
||||||
|
):
|
||||||
response = requests.get( # noqa: DUO123
|
response = requests.get( # noqa: DUO123
|
||||||
f"https://{island_server}/api?action=is-up",
|
f"https://{island_server}/api?action=is-up",
|
||||||
verify=False,
|
verify=False,
|
||||||
|
@ -89,3 +104,37 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
|
@handle_island_errors
|
||||||
|
def send_events(self, events: Sequence[JSONSerializable]):
|
||||||
|
response = requests.post( # noqa: DUO123
|
||||||
|
f"{self._api_url}/agent-events",
|
||||||
|
json=self._serialize_events(events),
|
||||||
|
verify=False,
|
||||||
|
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
||||||
|
serialized_events: List[JSONSerializable] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for e in events:
|
||||||
|
serializer = self._agent_event_serializer_registry[e.__class__]
|
||||||
|
serialized_events.append(serializer.serialize(e))
|
||||||
|
except Exception as err:
|
||||||
|
raise IslandAPIRequestError(err)
|
||||||
|
|
||||||
|
return serialized_events
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_event_serializer_registry: AgentEventSerializerRegistry = None,
|
||||||
|
):
|
||||||
|
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||||
|
|
||||||
|
def create_island_api_client(self):
|
||||||
|
return HTTPIslandAPIClient(self._agent_event_serializer_registry)
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
|
||||||
|
|
||||||
class IIslandAPIClient(ABC):
|
class IIslandAPIClient(ABC):
|
||||||
|
@ -8,9 +10,9 @@ class IIslandAPIClient(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, island_server: str):
|
def connect(self, island_server: str):
|
||||||
"""
|
"""
|
||||||
Construct an island API client and connect it to the island
|
Connectto the island's API
|
||||||
|
|
||||||
:param island_server: The socket address of the API
|
:param island_server: The socket address of the API
|
||||||
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
||||||
|
@ -71,4 +73,21 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
||||||
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
||||||
agent binary
|
agent binary
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def send_events(self, events: Sequence[AbstractAgentEvent]):
|
||||||
|
"""
|
||||||
|
Send a sequence of agent events to the Island
|
||||||
|
|
||||||
|
:param events: A sequence of agent events
|
||||||
|
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
||||||
|
:raises IslandAPIRequestError: If an error occurs while attempting to connect to the
|
||||||
|
island due to an issue in the request sent from the client
|
||||||
|
:raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the
|
||||||
|
island due to an error on the server
|
||||||
|
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
||||||
|
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to
|
||||||
|
the island
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -45,7 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
|
||||||
from infection_monkey.exploit.wmiexec import WmiExploiter
|
from infection_monkey.exploit.wmiexec import WmiExploiter
|
||||||
from infection_monkey.exploit.zerologon import ZerologonExploiter
|
from infection_monkey.exploit.zerologon import ZerologonExploiter
|
||||||
from infection_monkey.i_puppet import IPuppet, PluginType
|
from infection_monkey.i_puppet import IPuppet, PluginType
|
||||||
from infection_monkey.island_api_client import IIslandAPIClient
|
from infection_monkey.island_api_client import HTTPIslandAPIClientFactory, IIslandAPIClient
|
||||||
from infection_monkey.master import AutomatedMaster
|
from infection_monkey.master import AutomatedMaster
|
||||||
from infection_monkey.master.control_channel import ControlChannel
|
from infection_monkey.master.control_channel import ControlChannel
|
||||||
from infection_monkey.model import VictimHostFactory
|
from infection_monkey.model import VictimHostFactory
|
||||||
|
@ -107,9 +107,12 @@ logging.getLogger("urllib3").setLevel(logging.INFO)
|
||||||
class InfectionMonkey:
|
class InfectionMonkey:
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
logger.info("Monkey is initializing...")
|
logger.info("Monkey is initializing...")
|
||||||
|
|
||||||
self._singleton = SystemSingleton()
|
self._singleton = SystemSingleton()
|
||||||
self._opts = self._get_arguments(args)
|
self._opts = self._get_arguments(args)
|
||||||
|
|
||||||
|
self._agent_event_serializer_registry = self._setup_agent_event_serializers()
|
||||||
|
|
||||||
server, self._island_api_client = self._connect_to_island_api()
|
server, self._island_api_client = self._connect_to_island_api()
|
||||||
# TODO: `address_to_port()` should return the port as an integer.
|
# TODO: `address_to_port()` should return the port as an integer.
|
||||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
||||||
|
@ -141,7 +144,9 @@ class InfectionMonkey:
|
||||||
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
|
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
|
||||||
def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]:
|
def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]:
|
||||||
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
|
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
|
||||||
server_clients = find_available_island_apis(self._opts.servers)
|
server_clients = find_available_island_apis(
|
||||||
|
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
|
||||||
|
)
|
||||||
|
|
||||||
server, island_api_client = self._select_server(server_clients)
|
server, island_api_client = self._select_server(server_clients)
|
||||||
|
|
||||||
|
@ -206,8 +211,6 @@ class InfectionMonkey:
|
||||||
if firewall.is_enabled():
|
if firewall.is_enabled():
|
||||||
firewall.add_firewall_rule()
|
firewall.add_firewall_rule()
|
||||||
|
|
||||||
self._agent_event_serializer_registry = self._setup_agent_event_serializers()
|
|
||||||
|
|
||||||
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
|
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
|
||||||
self._control_channel.register_agent(self._opts.parent)
|
self._control_channel.register_agent(self._opts.parent)
|
||||||
|
|
||||||
|
@ -247,7 +250,7 @@ class InfectionMonkey:
|
||||||
)
|
)
|
||||||
|
|
||||||
event_queue = PyPubSubAgentEventQueue(Publisher())
|
event_queue = PyPubSubAgentEventQueue(Publisher())
|
||||||
InfectionMonkey._subscribe_events(
|
self._subscribe_events(
|
||||||
event_queue,
|
event_queue,
|
||||||
propagation_credentials_repository,
|
propagation_credentials_repository,
|
||||||
self._control_client.server_address,
|
self._control_client.server_address,
|
||||||
|
@ -273,8 +276,8 @@ class InfectionMonkey:
|
||||||
propagation_credentials_repository,
|
propagation_credentials_repository,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _subscribe_events(
|
def _subscribe_events(
|
||||||
|
self,
|
||||||
event_queue: IAgentEventQueue,
|
event_queue: IAgentEventQueue,
|
||||||
propagation_credentials_repository: IPropagationCredentialsRepository,
|
propagation_credentials_repository: IPropagationCredentialsRepository,
|
||||||
server_address: str,
|
server_address: str,
|
||||||
|
@ -287,7 +290,7 @@ class InfectionMonkey:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
event_queue.subscribe_all_events(
|
event_queue.subscribe_all_events(
|
||||||
AgentEventForwarder(server_address, agent_event_serializer_registry).send_event
|
AgentEventForwarder(self._island_api_client, agent_event_serializer_registry).send_event
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_puppet(
|
def _build_puppet(
|
||||||
|
|
|
@ -7,7 +7,7 @@ from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional,
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
||||||
from common.network.network_utils import address_to_ip_port
|
from common.network.network_utils import address_to_ip_port
|
||||||
from infection_monkey.island_api_client import (
|
from infection_monkey.island_api_client import (
|
||||||
HTTPIslandAPIClient,
|
AbstractIslandAPIClientFactory,
|
||||||
IIslandAPIClient,
|
IIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
IslandAPIError,
|
IslandAPIError,
|
||||||
|
@ -27,7 +27,9 @@ logger = logging.getLogger(__name__)
|
||||||
NUM_FIND_SERVER_WORKERS = 32
|
NUM_FIND_SERVER_WORKERS = 32
|
||||||
|
|
||||||
|
|
||||||
def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]:
|
def find_available_island_apis(
|
||||||
|
servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory
|
||||||
|
) -> Mapping[str, Optional[IIslandAPIClient]]:
|
||||||
server_list = list(servers)
|
server_list = list(servers)
|
||||||
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
||||||
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
|
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
|
||||||
|
@ -35,7 +37,7 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[
|
||||||
run_worker_threads(
|
run_worker_threads(
|
||||||
_find_island_server,
|
_find_island_server,
|
||||||
"FindIslandServer",
|
"FindIslandServer",
|
||||||
args=(server_iterator, server_results),
|
args=(server_iterator, server_results, island_api_client_factory),
|
||||||
num_workers=NUM_FIND_SERVER_WORKERS,
|
num_workers=NUM_FIND_SERVER_WORKERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,18 +45,25 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[
|
||||||
|
|
||||||
|
|
||||||
def _find_island_server(
|
def _find_island_server(
|
||||||
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
|
servers: Iterator[str],
|
||||||
|
server_status: MutableMapping[str, Optional[IIslandAPIClient]],
|
||||||
|
island_api_client_factory: AbstractIslandAPIClientFactory,
|
||||||
):
|
):
|
||||||
with suppress(StopIteration):
|
with suppress(StopIteration):
|
||||||
server = next(servers)
|
server = next(servers)
|
||||||
server_status[server] = _check_if_island_server(server)
|
server_status[server] = _check_if_island_server(server, island_api_client_factory)
|
||||||
|
|
||||||
|
|
||||||
def _check_if_island_server(server: str) -> IIslandAPIClient:
|
def _check_if_island_server(
|
||||||
|
server: str, island_api_client_factory: AbstractIslandAPIClientFactory
|
||||||
|
) -> IIslandAPIClient:
|
||||||
logger.debug(f"Trying to connect to server: {server}")
|
logger.debug(f"Trying to connect to server: {server}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return HTTPIslandAPIClient(server)
|
client = island_api_client_factory.create_island_api_client()
|
||||||
|
client.connect(server)
|
||||||
|
|
||||||
|
return client
|
||||||
except IslandAPIConnectionError as err:
|
except IslandAPIConnectionError as err:
|
||||||
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||||
except IslandAPITimeoutError as err:
|
except IslandAPITimeoutError as err:
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
import requests_mock
|
import requests_mock
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
|
from common.agent_event_serializers import (
|
||||||
|
AgentEventSerializerRegistry,
|
||||||
|
PydanticAgentEventSerializer,
|
||||||
|
)
|
||||||
|
from common.agent_events import AbstractAgentEvent
|
||||||
from infection_monkey.island_api_client import (
|
from infection_monkey.island_api_client import (
|
||||||
HTTPIslandAPIClient,
|
HTTPIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
|
@ -20,6 +27,35 @@ ISLAND_URI = f"https://{SERVER}/api?action=is-up"
|
||||||
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
||||||
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
||||||
ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
|
ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
|
||||||
|
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
|
||||||
|
|
||||||
|
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||||
|
|
||||||
|
|
||||||
|
class Event1(AbstractAgentEvent):
|
||||||
|
a: int
|
||||||
|
|
||||||
|
|
||||||
|
class Event2(AbstractAgentEvent):
|
||||||
|
b: str
|
||||||
|
|
||||||
|
|
||||||
|
class Event3(AbstractAgentEvent):
|
||||||
|
c: int
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_event_serializer_registry():
|
||||||
|
agent_event_serializer_registry = AgentEventSerializerRegistry()
|
||||||
|
agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1)
|
||||||
|
agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2)
|
||||||
|
|
||||||
|
return agent_event_serializer_registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def island_api_client(agent_event_serializer_registry):
|
||||||
|
return HTTPIslandAPIClient(agent_event_serializer_registry)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -30,12 +66,12 @@ ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
|
||||||
(Exception, IslandAPIError),
|
(Exception, IslandAPIError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client(actual_error, expected_error):
|
def test_island_api_client(island_api_client, actual_error, expected_error):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI, exc=actual_error)
|
m.get(ISLAND_URI, exc=actual_error)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -45,12 +81,12 @@ def test_island_api_client(actual_error, expected_error):
|
||||||
(501, IslandAPIRequestFailedError),
|
(501, IslandAPIRequestFailedError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client__status_code(status_code, expected_error):
|
def test_island_api_client__status_code(island_api_client, status_code, expected_error):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI, status_code=status_code)
|
m.get(ISLAND_URI, status_code=status_code)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -61,10 +97,10 @@ def test_island_api_client__status_code(status_code, expected_error):
|
||||||
(Exception, IslandAPIError),
|
(Exception, IslandAPIError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client__send_log(actual_error, expected_error):
|
def test_island_api_client__send_log(island_api_client, actual_error, expected_error):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI)
|
m.get(ISLAND_URI)
|
||||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.post(ISLAND_SEND_LOG_URI, exc=actual_error)
|
m.post(ISLAND_SEND_LOG_URI, exc=actual_error)
|
||||||
|
@ -78,10 +114,10 @@ def test_island_api_client__send_log(actual_error, expected_error):
|
||||||
(501, IslandAPIRequestFailedError),
|
(501, IslandAPIRequestFailedError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client_send_log__status_code(status_code, expected_error):
|
def test_island_api_client_send_log__status_code(island_api_client, status_code, expected_error):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI)
|
m.get(ISLAND_URI)
|
||||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.post(ISLAND_SEND_LOG_URI, status_code=status_code)
|
m.post(ISLAND_SEND_LOG_URI, status_code=status_code)
|
||||||
|
@ -96,10 +132,10 @@ def test_island_api_client_send_log__status_code(status_code, expected_error):
|
||||||
(Exception, IslandAPIError),
|
(Exception, IslandAPIError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client__get_pba_file(actual_error, expected_error):
|
def test_island_api_client__get_pba_file(island_api_client, actual_error, expected_error):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI)
|
m.get(ISLAND_URI)
|
||||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error)
|
m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error)
|
||||||
|
@ -113,10 +149,12 @@ def test_island_api_client__get_pba_file(actual_error, expected_error):
|
||||||
(501, IslandAPIRequestFailedError),
|
(501, IslandAPIRequestFailedError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client_get_pba_file__status_code(status_code, expected_error):
|
def test_island_api_client_get_pba_file__status_code(
|
||||||
|
island_api_client, status_code, expected_error
|
||||||
|
):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI)
|
m.get(ISLAND_URI)
|
||||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code)
|
m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code)
|
||||||
|
@ -131,10 +169,10 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
|
||||||
(Exception, IslandAPIError),
|
(Exception, IslandAPIError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client__get_agent_binary(actual_error, expected_error):
|
def test_island_api_client__get_agent_binary(island_api_client, actual_error, expected_error):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI)
|
m.get(ISLAND_URI)
|
||||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.get(ISLAND_GET_AGENT_BINARY_URI, exc=actual_error)
|
m.get(ISLAND_GET_AGENT_BINARY_URI, exc=actual_error)
|
||||||
|
@ -148,11 +186,92 @@ def test_island_api_client__get_agent_binary(actual_error, expected_error):
|
||||||
(501, IslandAPIRequestFailedError),
|
(501, IslandAPIRequestFailedError),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_island_api_client__get_agent_binary_status_code(status_code, expected_error):
|
def test_island_api_client__get_agent_binary_status_code(
|
||||||
|
island_api_client, status_code, expected_error
|
||||||
|
):
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get(ISLAND_URI)
|
m.get(ISLAND_URI)
|
||||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.get(ISLAND_GET_AGENT_BINARY_URI, status_code=status_code)
|
m.get(ISLAND_GET_AGENT_BINARY_URI, status_code=status_code)
|
||||||
island_api_client.get_agent_binary(operating_system=OperatingSystem.WINDOWS)
|
island_api_client.get_agent_binary(operating_system=OperatingSystem.WINDOWS)
|
||||||
|
|
||||||
|
|
||||||
|
def test_island_api_client_send_events__serialization(island_api_client):
|
||||||
|
events_to_send = [
|
||||||
|
Event1(source=AGENT_ID, timestamp=0, a=1),
|
||||||
|
Event2(source=AGENT_ID, timestamp=0, b="hello"),
|
||||||
|
]
|
||||||
|
expected_json = [
|
||||||
|
{
|
||||||
|
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
|
||||||
|
"target": None,
|
||||||
|
"timestamp": 0.0,
|
||||||
|
"tags": [],
|
||||||
|
"a": 1,
|
||||||
|
"type": "Event1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
|
||||||
|
"target": None,
|
||||||
|
"timestamp": 0.0,
|
||||||
|
"tags": [],
|
||||||
|
"b": "hello",
|
||||||
|
"type": "Event2",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
m.post(ISLAND_SEND_EVENTS_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
island_api_client.send_events(events=events_to_send)
|
||||||
|
|
||||||
|
assert m.last_request.json() == expected_json
|
||||||
|
|
||||||
|
|
||||||
|
def test_island_api_client_send_events__serialization_failed(island_api_client):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(IslandAPIRequestError):
|
||||||
|
m.post(ISLAND_SEND_EVENTS_URI)
|
||||||
|
island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"actual_error, expected_error",
|
||||||
|
[
|
||||||
|
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||||
|
(TimeoutError, IslandAPITimeoutError),
|
||||||
|
(Exception, IslandAPIError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client__send_events(island_api_client, actual_error, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error)
|
||||||
|
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code, expected_error",
|
||||||
|
[
|
||||||
|
(401, IslandAPIRequestError),
|
||||||
|
(501, IslandAPIRequestFailedError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client_send_events__status_code(island_api_client, status_code, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
|
||||||
|
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
import pytest
|
import pytest
|
||||||
import requests_mock
|
import requests_mock
|
||||||
|
|
||||||
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
|
from common.agent_event_serializers import AgentEventSerializerRegistry
|
||||||
|
from infection_monkey.island_api_client import (
|
||||||
|
HTTPIslandAPIClientFactory,
|
||||||
|
IIslandAPIClient,
|
||||||
|
IslandAPIConnectionError,
|
||||||
|
)
|
||||||
from infection_monkey.network.relay.utils import find_available_island_apis
|
from infection_monkey.network.relay.utils import find_available_island_apis
|
||||||
|
|
||||||
SERVER_1 = "1.1.1.1:12312"
|
SERVER_1 = "1.1.1.1:12312"
|
||||||
|
@ -13,6 +18,11 @@ SERVER_4 = "4.4.4.4:5000"
|
||||||
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def island_api_client_factory():
|
||||||
|
return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"expected_available_servers, server_response_pairs",
|
"expected_available_servers, server_response_pairs",
|
||||||
[
|
[
|
||||||
|
@ -24,12 +34,14 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_find_available_island_apis(expected_available_servers, server_response_pairs):
|
def test_find_available_island_apis(
|
||||||
|
expected_available_servers, server_response_pairs, island_api_client_factory
|
||||||
|
):
|
||||||
with requests_mock.Mocker() as mock:
|
with requests_mock.Mocker() as mock:
|
||||||
for server, response in server_response_pairs:
|
for server, response in server_response_pairs:
|
||||||
mock.get(f"https://{server}/api?action=is-up", **response)
|
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||||
|
|
||||||
available_apis = find_available_island_apis(servers)
|
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||||
|
|
||||||
assert len(available_apis) == len(server_response_pairs)
|
assert len(available_apis) == len(server_response_pairs)
|
||||||
|
|
||||||
|
@ -40,14 +52,14 @@ def test_find_available_island_apis(expected_available_servers, server_response_
|
||||||
assert island_api_client is None
|
assert island_api_client is None
|
||||||
|
|
||||||
|
|
||||||
def test_find_available_island_apis__multiple_successes():
|
def test_find_available_island_apis__multiple_successes(island_api_client_factory):
|
||||||
available_servers = [SERVER_2, SERVER_3]
|
available_servers = [SERVER_2, SERVER_3]
|
||||||
with requests_mock.Mocker() as mock:
|
with requests_mock.Mocker() as mock:
|
||||||
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
|
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
|
||||||
for server in available_servers:
|
for server in available_servers:
|
||||||
mock.get(f"https://{server}/api?action=is-up", text="")
|
mock.get(f"https://{server}/api?action=is-up", text="")
|
||||||
|
|
||||||
available_apis = find_available_island_apis(servers)
|
available_apis = find_available_island_apis(servers, island_api_client_factory)
|
||||||
|
|
||||||
assert available_apis[SERVER_1] is None
|
assert available_apis[SERVER_1] is None
|
||||||
assert available_apis[SERVER_4] is None
|
assert available_apis[SERVER_4] is None
|
||||||
|
|
|
@ -1,16 +1,22 @@
|
||||||
import time
|
import time
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests_mock
|
|
||||||
|
|
||||||
from infection_monkey.agent_event_forwarder import AGENT_EVENTS_API_URL, BatchingAgentEventForwarder
|
from infection_monkey.agent_event_forwarder import BatchingAgentEventForwarder
|
||||||
|
from infection_monkey.island_api_client import IIslandAPIClient
|
||||||
|
|
||||||
SERVER = "1.1.1.1:9999"
|
SERVER = "1.1.1.1:9999"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def event_sender():
|
def mock_island_api_client():
|
||||||
return BatchingAgentEventForwarder(SERVER, time_period=0.001)
|
return MagicMock(spec=IIslandAPIClient)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def event_sender(mock_island_api_client):
|
||||||
|
return BatchingAgentEventForwarder(mock_island_api_client, time_period=0.001)
|
||||||
|
|
||||||
|
|
||||||
# NOTE: If these tests are too slow or end up being racey, we can redesign AgentEventForwarder to
|
# NOTE: If these tests are too slow or end up being racey, we can redesign AgentEventForwarder to
|
||||||
|
@ -18,35 +24,29 @@ def event_sender():
|
||||||
# BatchingAgentEventForwarder would have unit tests, but AgentEventForwarder would not.
|
# BatchingAgentEventForwarder would have unit tests, but AgentEventForwarder would not.
|
||||||
|
|
||||||
|
|
||||||
def test_send_events(event_sender):
|
def test_send_events(event_sender, mock_island_api_client):
|
||||||
with requests_mock.Mocker() as mock:
|
event_sender.start()
|
||||||
mock.post(AGENT_EVENTS_API_URL % SERVER)
|
|
||||||
|
|
||||||
event_sender.start()
|
|
||||||
|
|
||||||
for _ in range(5):
|
|
||||||
event_sender.add_event_to_queue({})
|
|
||||||
time.sleep(0.02)
|
|
||||||
assert mock.call_count == 1
|
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
event_sender.add_event_to_queue({})
|
event_sender.add_event_to_queue({})
|
||||||
time.sleep(0.02)
|
time.sleep(0.05)
|
||||||
assert mock.call_count == 2
|
assert mock_island_api_client.send_events.call_count == 1
|
||||||
|
|
||||||
event_sender.stop()
|
event_sender.add_event_to_queue({})
|
||||||
|
time.sleep(0.05)
|
||||||
|
assert mock_island_api_client.send_events.call_count == 2
|
||||||
|
|
||||||
|
event_sender.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_send_remaining_events(event_sender):
|
def test_send_remaining_events(event_sender, mock_island_api_client):
|
||||||
with requests_mock.Mocker() as mock:
|
event_sender.start()
|
||||||
mock.post(AGENT_EVENTS_API_URL % SERVER)
|
|
||||||
|
|
||||||
event_sender.start()
|
|
||||||
|
|
||||||
for _ in range(5):
|
|
||||||
event_sender.add_event_to_queue({})
|
|
||||||
time.sleep(0.02)
|
|
||||||
assert mock.call_count == 1
|
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
event_sender.add_event_to_queue({})
|
event_sender.add_event_to_queue({})
|
||||||
event_sender.stop()
|
time.sleep(0.05)
|
||||||
assert mock.call_count == 2
|
assert mock_island_api_client.send_events.call_count == 1
|
||||||
|
|
||||||
|
event_sender.add_event_to_queue({})
|
||||||
|
event_sender.stop()
|
||||||
|
assert mock_island_api_client.send_events.call_count == 2
|
||||||
|
|
Loading…
Reference in New Issue