Merge pull request #2349 from guardicore/2261-refactor-manual-agent-logic

2261 refactor manual agent logic
This commit is contained in:
Mike Salvatore 2022-09-23 13:45:55 -04:00 committed by GitHub
commit dbaa56c39d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 589 additions and 334 deletions

View File

@ -24,6 +24,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
- The ability to customize the file extension used by ransomware when - The ability to customize the file extension used by ransomware when
encrypting files. #1242 encrypting files. #1242
- `/api/agents` endpoint. - `/api/agents` endpoint.
- `/api/agent-signals` endpoint. #2261
### Changed ### Changed
- Reset workflow. Now it's possible to delete data gathered by agents without - Reset workflow. Now it's possible to delete data gathered by agents without
@ -64,6 +65,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
- Tunneling to relays to provide better firewall evasion, faster Island - Tunneling to relays to provide better firewall evasion, faster Island
connection times, unlimited hops, and a more resilient way for agents to call connection times, unlimited hops, and a more resilient way for agents to call
home. #2216, #1583 home. #2216, #1583
- "/api/monkey-control/stop-all-agents" to "/api/agent-signals/terminate-all-agents" #2261
### Removed ### Removed
- VSFTPD exploiter. #1533 - VSFTPD exploiter. #1533
@ -109,6 +111,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
- "/api/configuration/export" endpoint. #2002 - "/api/configuration/export" endpoint. #2002
- "/api/island-configuration" endpoint. #2003 - "/api/island-configuration" endpoint. #2003
- "-t/--tunnel" from agent command line arguments. #2216 - "-t/--tunnel" from agent command line arguments. #2216
- "/api/monkey-control/neets-to-stop". #2261
### Fixed ### Fixed
- A bug in network map page that caused delay of telemetry log loading. #1545 - A bug in network map page that caused delay of telemetry log loading. #1545

View File

@ -88,8 +88,9 @@ class MonkeyIslandClient(object):
@avoid_race_condition @avoid_race_condition
def kill_all_monkeys(self): def kill_all_monkeys(self):
# TODO change this request, because monkey-control resource got removed
response = self.requests.post_json( response = self.requests.post_json(
"api/monkey-control/stop-all-agents", json={"kill_time": time.time()} "api/agent-signals/terminate-all-agents", json={"terminate_time": time.time()}
) )
if response.ok: if response.ok:
LOGGER.info("Killing all monkeys after the test.") LOGGER.info("Killing all monkeys after the test.")

View File

@ -7,3 +7,4 @@ from .operating_system import OperatingSystem
from . import types from . import types
from . import base_models from . import base_models
from .agent_registration_data import AgentRegistrationData from .agent_registration_data import AgentRegistrationData
from .agent_signals import AgentSignals

View File

@ -0,0 +1,8 @@
from datetime import datetime
from typing import Optional
from .base_models import InfectionMonkeyBaseModel
class AgentSignals(InfectionMonkeyBaseModel):
terminate: Optional[datetime]

View File

@ -6,7 +6,7 @@ from typing import List, Sequence
import requests import requests
from common import AgentRegistrationData, OperatingSystem from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
@ -146,19 +146,6 @@ class HTTPIslandAPIClient(IIslandAPIClient):
) )
response.raise_for_status() response.raise_for_status()
@handle_island_errors
@convert_json_error_to_island_api_error
def should_agent_stop(self, agent_id: str) -> bool:
url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}"
response = requests.get( # noqa: DUO123
url,
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return response.json()["stop_agent"]
@handle_island_errors @handle_island_errors
@convert_json_error_to_island_api_error @convert_json_error_to_island_api_error
def get_config(self) -> AgentConfiguration: def get_config(self) -> AgentConfiguration:
@ -199,6 +186,18 @@ class HTTPIslandAPIClient(IIslandAPIClient):
return serialized_events return serialized_events
@handle_island_errors
@convert_json_error_to_island_api_error
def get_agent_signals(self, agent_id: str) -> AgentSignals:
url = f"{self._api_url}/agent-signals/{agent_id}"
response = requests.get( # noqa: DUO123
url,
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return AgentSignals(**response.json())
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
def __init__( def __init__(

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Sequence from typing import Sequence
from common import AgentRegistrationData, OperatingSystem from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials from common.credentials import Credentials
@ -107,19 +107,6 @@ class IIslandAPIClient(ABC):
:raises IslandAPITimeoutError: If the command timed out :raises IslandAPITimeoutError: If the command timed out
""" """
@abstractmethod
def should_agent_stop(self, agent_id: str) -> bool:
"""
Check with the island to see if the agent should stop
:param agent_id: The agent identifier for the agent to check
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: True if the agent should stop, otherwise False
"""
@abstractmethod @abstractmethod
def get_config(self) -> AgentConfiguration: def get_config(self) -> AgentConfiguration:
""" """
@ -143,3 +130,16 @@ class IIslandAPIClient(ABC):
:raises IslandAPITimeoutError: If the command timed out :raises IslandAPITimeoutError: If the command timed out
:return: Credentials :return: Credentials
""" """
@abstractmethod
def get_agent_signals(self, agent_id: str) -> AgentSignals:
"""
Gets an agent's signals from the island
:param agent_id: ID of the agent whose signals should be retrieved
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: The relevant agent's signals
"""

View File

@ -36,7 +36,8 @@ class ControlChannel(IControlChannel):
if not self._control_channel_server: if not self._control_channel_server:
logger.error("Agent should stop because it can't connect to the C&C server.") logger.error("Agent should stop because it can't connect to the C&C server.")
return True return True
return self._island_api_client.should_agent_stop(self._agent_id) agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
return agent_signals.terminate is not None
@handle_island_api_errors @handle_island_api_errors
def get_config(self) -> AgentConfiguration: def get_config(self) -> AgentConfiguration:

View File

@ -123,7 +123,7 @@ class InfectionMonkey:
self._control_client = ControlClient( self._control_client = ControlClient(
server_address=server, island_api_client=self._island_api_client server_address=server, island_api_client=self._island_api_client
) )
self._control_channel = ControlChannel(server, GUID, self._island_api_client) self._control_channel = ControlChannel(server, get_agent_id(), self._island_api_client)
self._register_agent(server) self._register_agent(server)
# TODO Refactor the telemetry messengers to accept control client # TODO Refactor the telemetry messengers to accept control client

View File

@ -15,6 +15,7 @@ from monkey_island.cc.resources import (
AgentConfiguration, AgentConfiguration,
AgentEvents, AgentEvents,
Agents, Agents,
AgentSignals,
ClearSimulationData, ClearSimulationData,
IPAddresses, IPAddresses,
IslandLog, IslandLog,
@ -23,9 +24,9 @@ from monkey_island.cc.resources import (
PropagationCredentials, PropagationCredentials,
RemoteRun, RemoteRun,
ResetAgentConfiguration, ResetAgentConfiguration,
TerminateAllAgents,
) )
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
from monkey_island.cc.resources.attack.attack_report import AttackReport from monkey_island.cc.resources.attack.attack_report import AttackReport
from monkey_island.cc.resources.auth import Authenticate, Register, RegistrationStatus, init_jwt from monkey_island.cc.resources.auth import Authenticate, Register, RegistrationStatus, init_jwt
from monkey_island.cc.resources.blackbox.log_blackbox_endpoint import LogBlackboxEndpoint from monkey_island.cc.resources.blackbox.log_blackbox_endpoint import LogBlackboxEndpoint
@ -188,6 +189,7 @@ def init_restful_endpoints(api: FlaskDIWrapper):
api.add_resource(IPAddresses) api.add_resource(IPAddresses)
api.add_resource(AgentEvents) api.add_resource(AgentEvents)
api.add_resource(AgentSignals)
# API Spec: These two should be the same resource, GET for download and POST for upload # API Spec: These two should be the same resource, GET for download and POST for upload
api.add_resource(PBAFileDownload) api.add_resource(PBAFileDownload)
@ -196,8 +198,6 @@ def init_restful_endpoints(api: FlaskDIWrapper):
api.add_resource(PropagationCredentials) api.add_resource(PropagationCredentials)
api.add_resource(RemoteRun) api.add_resource(RemoteRun)
api.add_resource(Version) api.add_resource(Version)
api.add_resource(StopAgentCheck)
api.add_resource(StopAllAgents)
# Resources used by black box tests # Resources used by black box tests
# API Spec: Fix all the following endpoints, see comments in the resource classes # API Spec: Fix all the following endpoints, see comments in the resource classes
@ -211,6 +211,7 @@ def init_restful_endpoints(api: FlaskDIWrapper):
def init_rpc_endpoints(api: FlaskDIWrapper): def init_rpc_endpoints(api: FlaskDIWrapper):
api.add_resource(ResetAgentConfiguration) api.add_resource(ResetAgentConfiguration)
api.add_resource(ClearSimulationData) api.add_resource(ClearSimulationData)
api.add_resource(TerminateAllAgents)
def init_app(mongo_url: str, container: DIContainer): def init_app(mongo_url: str, container: DIContainer):

View File

@ -9,6 +9,7 @@ class IslandEventTopic(Enum):
CLEAR_SIMULATION_DATA = auto() CLEAR_SIMULATION_DATA = auto()
RESET_AGENT_CONFIGURATION = auto() RESET_AGENT_CONFIGURATION = auto()
SET_ISLAND_MODE = auto() SET_ISLAND_MODE = auto()
TERMINATE_AGENTS = auto()
class IIslandEventQueue(ABC): class IIslandEventQueue(ABC):

View File

@ -1 +0,0 @@
from .agent_controls import AgentControls

View File

@ -1,8 +0,0 @@
from mongoengine import Document, FloatField
# TODO rename to Simulation, add other metadata
class AgentControls(Document):
# Timestamp of the last "kill all agents" command
last_stop_all = FloatField(default=None)

View File

@ -22,10 +22,6 @@ from monkey_island.cc.models.monkey_ttl import MonkeyTtl, create_monkey_ttl_docu
from monkey_island.cc.server_utils.consts import DEFAULT_MONKEY_TTL_EXPIRY_DURATION_IN_SECONDS from monkey_island.cc.server_utils.consts import DEFAULT_MONKEY_TTL_EXPIRY_DURATION_IN_SECONDS
class ParentNotFoundError(Exception):
"""Raise when trying to get a parent of monkey that doesn't have one"""
class Monkey(Document): class Monkey(Document):
""" """
This class has 2 main section: This class has 2 main section:
@ -98,18 +94,6 @@ class Monkey(Document):
monkey_is_dead = True monkey_is_dead = True
return monkey_is_dead return monkey_is_dead
def has_parent(self):
for p in self.parent:
if p[0] != self.guid:
return True
return False
def get_parent(self):
if self.has_parent():
return Monkey.objects(guid=self.parent[0][0]).first()
else:
raise ParentNotFoundError(f"No parent was found for agent with GUID {self.guid}")
def get_os(self): def get_os(self):
os = "unknown" os = "unknown"
if self.description.lower().find("linux") != -1: if self.description.lower().find("linux") != -1:

View File

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional
from common.base_models import InfectionMonkeyBaseModel from common.base_models import InfectionMonkeyBaseModel
@ -13,3 +15,4 @@ class IslandMode(Enum):
class Simulation(InfectionMonkeyBaseModel): class Simulation(InfectionMonkeyBaseModel):
mode: IslandMode = IslandMode.UNSET mode: IslandMode = IslandMode.UNSET
terminate_signal_time: Optional[datetime] = None

View File

@ -16,7 +16,7 @@ class IAgentRepository(ABC):
already exists, update it. already exists, update it.
:param agent: The `agent` to be inserted or updated :param agent: The `agent` to be inserted or updated
:raises StorageError: If an error occurred while attempting to store the `Agent` :raises StorageError: If an error occurs while attempting to store the `Agent`
""" """
@abstractmethod @abstractmethod
@ -28,7 +28,7 @@ class IAgentRepository(ABC):
:return: An `Agent` with a matching `id` :return: An `Agent` with a matching `id`
:raises UnknownRecordError: If an `Agent` with the specified `id` does not exist in the :raises UnknownRecordError: If an `Agent` with the specified `id` does not exist in the
repository repository
:raises RetrievalError: If an error occurred while attempting to retrieve the `Agent` :raises RetrievalError: If an error occurs while attempting to retrieve the `Agent`
""" """
@abstractmethod @abstractmethod
@ -37,7 +37,18 @@ class IAgentRepository(ABC):
Get all `Agents` that are currently running Get all `Agents` that are currently running
:return: All `Agents` that are currently running :return: All `Agents` that are currently running
:raises RetrievalError: If an error occurred while attempting to retrieve the `Agents` :raises RetrievalError: If an error occurs while attempting to retrieve the `Agents`
"""
@abstractmethod
def get_progenitor(self, agent: Agent) -> Agent:
"""
Gets the progenitor `Agent` for the agent.
:param agent: The Agent for which we want the progenitor
:return: `Agent` progenitor ( an initial agent that started the exploitation chain)
:raises RetrievalError: If an error occurrs while attempting to retrieve the `Agent`
:raises UnknownRecordError: If the agent ID is not in the repository
""" """
@abstractmethod @abstractmethod

View File

@ -26,7 +26,7 @@ class IMachineRepository(ABC):
`Machine` already exists, update it. `Machine` already exists, update it.
:param machine: The `Machine` to be inserted or updated :param machine: The `Machine` to be inserted or updated
:raises StorageError: If an error occurred while attempting to store the `Machine` :raises StorageError: If an error occurs while attempting to store the `Machine`
""" """
@abstractmethod @abstractmethod
@ -38,7 +38,7 @@ class IMachineRepository(ABC):
:return: A `Machine` with a matching `id` :return: A `Machine` with a matching `id`
:raises UnknownRecordError: If a `Machine` with the specified `id` does not exist in the :raises UnknownRecordError: If a `Machine` with the specified `id` does not exist in the
repository repository
:raises RetrievalError: If an error occurred while attempting to retrieve the `Machine` :raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`
""" """
@abstractmethod @abstractmethod
@ -50,7 +50,7 @@ class IMachineRepository(ABC):
:return: A `Machine` with a matching `hardware_id` :return: A `Machine` with a matching `hardware_id`
:raises UnknownRecordError: If a `Machine` with the specified `hardware_id` does not exist :raises UnknownRecordError: If a `Machine` with the specified `hardware_id` does not exist
in the repository in the repository
:raises RetrievalError: If an error occurred while attempting to retrieve the `Machine` :raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`
""" """
@abstractmethod @abstractmethod
@ -62,7 +62,7 @@ class IMachineRepository(ABC):
:return: A sequence of Machines that have a network interface with a matching IP :return: A sequence of Machines that have a network interface with a matching IP
:raises UnknownRecordError: If a `Machine` with the specified `ip` does not exist in the :raises UnknownRecordError: If a `Machine` with the specified `ip` does not exist in the
repository repository
:raises RetrievalError: If an error occurred while attempting to retrieve the `Machine` :raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`
""" """
@abstractmethod @abstractmethod
@ -70,6 +70,6 @@ class IMachineRepository(ABC):
""" """
Removes all data from the repository Removes all data from the repository
:raises RemovalError: If an error occurred while attempting to remove all `Machines` from :raises RemovalError: If an error occurs while attempting to remove all `Machines` from
the repository the repository
""" """

View File

@ -22,7 +22,7 @@ class INodeRepository(ABC):
:param src: The machine that the connection or communication originated from :param src: The machine that the connection or communication originated from
:param dst: The machine that the src communicated with :param dst: The machine that the src communicated with
:param communication_type: The way the machines communicated :param communication_type: The way the machines communicated
:raises StorageError: If an error occurred while attempting to upsert the Node :raises StorageError: If an error occurs while attempting to upsert the Node
""" """
@abstractmethod @abstractmethod
@ -31,7 +31,7 @@ class INodeRepository(ABC):
Return all nodes that are stored in the repository Return all nodes that are stored in the repository
:return: All known Nodes :return: All known Nodes
:raises RetrievalError: If an error occurred while attempting to retrieve the nodes :raises RetrievalError: If an error occurs while attempting to retrieve the nodes
""" """
@abstractmethod @abstractmethod
@ -39,6 +39,6 @@ class INodeRepository(ABC):
""" """
Removes all data from the repository Removes all data from the repository
:raises RemovalError: If an error occurred while attempting to remove all `Nodes` from the :raises RemovalError: If an error occurs while attempting to remove all `Nodes` from the
repository repository
""" """

View File

@ -58,6 +58,14 @@ class MongoAgentRepository(IAgentRepository):
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving running agents: {err}") raise RetrievalError(f"Error retrieving running agents: {err}")
def get_progenitor(self, agent: Agent) -> Agent:
if agent.parent_id is None:
return agent
parent = self.get_agent_by_id(agent.parent_id)
return self.get_progenitor(parent)
def reset(self): def reset(self):
try: try:
self._agents_collection.drop() self._agents_collection.drop()

View File

@ -10,3 +10,4 @@ from .pba_file_upload import PBAFileUpload, LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
from .pba_file_download import PBAFileDownload from .pba_file_download import PBAFileDownload
from .agent_events import AgentEvents from .agent_events import AgentEvents
from .agents import Agents from .agents import Agents
from .agent_signals import AgentSignals, TerminateAllAgents

View File

@ -1,2 +0,0 @@
from .stop_all_agents import StopAllAgents
from .stop_agent_check import StopAgentCheck

View File

@ -1,11 +0,0 @@
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.services.infection_lifecycle import should_agent_die
class StopAgentCheck(AbstractResource):
# API Spec: Rename to AgentStopStatus or something, endpoint for this could be
# "/api/agents/<GUID>/stop-status"
urls = ["/api/monkey-control/needs-to-stop/<int:monkey_guid>"]
def get(self, monkey_guid: int):
return {"stop_agent": should_agent_die(monkey_guid)}

View File

@ -1,27 +0,0 @@
import json
from flask import make_response, request
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required
from monkey_island.cc.resources.utils.semaphores import agent_killing_mutex
from monkey_island.cc.services.infection_lifecycle import set_stop_all, should_agent_die
class StopAllAgents(AbstractResource):
# API Spec: This is an action and there's no "resource"; RPC-style endpoint?
urls = ["/api/monkey-control/stop-all-agents"]
@jwt_required
def post(self):
with agent_killing_mutex:
data = json.loads(request.data)
if data["kill_time"]:
set_stop_all(data["kill_time"])
return make_response({}, 200)
else:
return make_response({}, 400)
# API Spec: This is the exact same thing as what's in StopAgentCheck
def get(self, monkey_guid):
return {"stop_agent": should_agent_die(monkey_guid)}

View File

@ -0,0 +1,2 @@
from .agent_signals import AgentSignals
from .terminate_all_agents import TerminateAllAgents

View File

@ -0,0 +1,21 @@
import logging
from http import HTTPStatus
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.services import AgentSignalsService
logger = logging.getLogger(__name__)
class AgentSignals(AbstractResource):
urls = ["/api/agent-signals/<string:agent_id>"]
def __init__(
self,
agent_signals_service: AgentSignalsService,
):
self._agent_signals_service = agent_signals_service
def get(self, agent_id: str):
agent_signals = self._agent_signals_service.get_signals(agent_id)
return agent_signals.dict(simplify=True), HTTPStatus.OK

View File

@ -0,0 +1,39 @@
import logging
from http import HTTPStatus
from json import JSONDecodeError
from flask import request
from monkey_island.cc.event_queue import IIslandEventQueue, IslandEventTopic
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required
logger = logging.getLogger(__name__)
class TerminateAllAgents(AbstractResource):
urls = ["/api/agent-signals/terminate-all-agents"]
def __init__(
self,
island_event_queue: IIslandEventQueue,
):
self._island_event_queue = island_event_queue
@jwt_required
def post(self):
try:
terminate_timestamp = request.json["terminate_time"]
if terminate_timestamp is None:
raise ValueError("Terminate signal's timestamp is empty")
elif terminate_timestamp <= 0:
raise ValueError("Terminate signal's timestamp is not a positive integer")
self._island_event_queue.publish(
IslandEventTopic.TERMINATE_AGENTS, timestamp=terminate_timestamp
)
except (JSONDecodeError, TypeError, ValueError, KeyError) as err:
return {"error": err}, HTTPStatus.BAD_REQUEST
return {}, HTTPStatus.NO_CONTENT

View File

@ -1,3 +1,4 @@
from .agent_signals_service import AgentSignalsService
from .authentication_service import AuthenticationService from .authentication_service import AuthenticationService
from .aws import AWSService from .aws import AWSService

View File

@ -0,0 +1,55 @@
import logging
from datetime import datetime
from typing import Optional
from common.agent_signals import AgentSignals
from common.types import AgentID
from monkey_island.cc.models import Simulation
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
logger = logging.getLogger(__name__)
class AgentSignalsService:
def __init__(
self, simulation_repository: ISimulationRepository, agent_repository: IAgentRepository
):
self._simulation_repository = simulation_repository
self._agent_repository = agent_repository
def get_signals(self, agent_id: AgentID) -> AgentSignals:
"""
Gets the signals sent to a particular agent
:param agent_id: The ID of the agent whose signals need to be retrieved
:return: Signals sent to the relevant agent
"""
terminate_timestamp = self._get_terminate_signal_timestamp(agent_id)
return AgentSignals(terminate=terminate_timestamp)
def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]:
simulation = self._simulation_repository.get_simulation()
terminate_all_signal_time = simulation.terminate_signal_time
if terminate_all_signal_time is None:
return None
agent = self._agent_repository.get_agent_by_id(agent_id)
if agent.start_time <= terminate_all_signal_time:
return terminate_all_signal_time
progenitor = self._agent_repository.get_progenitor(agent)
if progenitor.start_time <= terminate_all_signal_time:
return terminate_all_signal_time
return None
def on_terminate_agents_signal(self, timestamp: datetime):
"""
Updates the simulation repository with the terminate signal's timestamp
:param timestamp: Timestamp of the terminate signal
"""
simulation = self._simulation_repository.get_simulation()
updated_simulation = Simulation(mode=simulation.mode, terminate_signal_time=timestamp)
self._simulation_repository.save_simulation(updated_simulation)

View File

@ -4,7 +4,6 @@ from flask import jsonify
from monkey_island.cc.database import mongo from monkey_island.cc.database import mongo
from monkey_island.cc.models import Config from monkey_island.cc.models import Config
from monkey_island.cc.models.agent_controls import AgentControls
from monkey_island.cc.models.attack.attack_mitigations import AttackMitigations from monkey_island.cc.models.attack.attack_mitigations import AttackMitigations
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,7 +23,6 @@ class Database(object):
for x in mongo.db.collection_names() for x in mongo.db.collection_names()
if Database._should_drop(x, reset_config) if Database._should_drop(x, reset_config)
] ]
Database.init_agent_controls()
logger.info("DB was reset") logger.info("DB was reset")
return jsonify(status="OK") return jsonify(status="OK")
@ -44,10 +42,6 @@ class Database(object):
mongo.db[collection_name].drop() mongo.db[collection_name].drop()
logger.info("Dropped collection {}".format(collection_name)) logger.info("Dropped collection {}".format(collection_name))
@staticmethod
def init_agent_controls():
AgentControls().save()
@staticmethod @staticmethod
def is_mitigations_missing() -> bool: def is_mitigations_missing() -> bool:
return bool(AttackMitigations.COLLECTION_NAME not in mongo.db.list_collection_names()) return bool(AttackMitigations.COLLECTION_NAME not in mongo.db.list_collection_names())

View File

@ -1,7 +1,5 @@
import logging import logging
from monkey_island.cc.models import Monkey
from monkey_island.cc.models.agent_controls import AgentControls
from monkey_island.cc.services.node import NodeService from monkey_island.cc.services.node import NodeService
from monkey_island.cc.services.reporting.report import ReportService from monkey_island.cc.services.reporting.report import ReportService
from monkey_island.cc.services.reporting.report_generation_synchronisation import ( from monkey_island.cc.services.reporting.report_generation_synchronisation import (
@ -12,41 +10,6 @@ from monkey_island.cc.services.reporting.report_generation_synchronisation impor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def set_stop_all(time: float):
# This will use Agent and Simulation repositories
for monkey in Monkey.objects():
monkey.should_stop = True
monkey.save()
agent_controls = AgentControls.objects.first()
agent_controls.last_stop_all = time
agent_controls.save()
def should_agent_die(guid: int) -> bool:
monkey = Monkey.objects(guid=str(guid)).first()
return _should_agent_stop(monkey) or _is_monkey_killed_manually(monkey)
def _should_agent_stop(monkey: Monkey) -> bool:
if monkey.should_stop:
# Only stop the agent once, to allow further runs on that machine
monkey.should_stop = False
monkey.save()
return True
return False
def _is_monkey_killed_manually(monkey: Monkey) -> bool:
kill_timestamp = AgentControls.objects.first().last_stop_all
if kill_timestamp is None:
return False
if monkey.has_parent():
launch_timestamp = monkey.get_parent().launch_time
else:
launch_timestamp = monkey.launch_time
return int(kill_timestamp) >= int(launch_timestamp)
def get_completed_steps(): def get_completed_steps():
is_any_exists = NodeService.is_any_monkey_exists() is_any_exists = NodeService.is_any_monkey_exists()
infection_done = NodeService.is_monkey_finished_running() infection_done = NodeService.is_monkey_finished_running()

View File

@ -47,7 +47,7 @@ from monkey_island.cc.repository import (
) )
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.server_utils.encryption import ILockableEncryptor, RepositoryEncryptor from monkey_island.cc.server_utils.encryption import ILockableEncryptor, RepositoryEncryptor
from monkey_island.cc.services import AWSService from monkey_island.cc.services import AgentSignalsService, AWSService
from monkey_island.cc.services.attack.technique_reports.T1003 import T1003, T1003GetReportData from monkey_island.cc.services.attack.technique_reports.T1003 import T1003, T1003GetReportData
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
from monkey_island.cc.setup.mongo.mongo_setup import MONGO_URL from monkey_island.cc.setup.mongo.mongo_setup import MONGO_URL
@ -176,6 +176,7 @@ def _register_services(container: DIContainer):
container.register_instance(AWSService, container.resolve(AWSService)) container.register_instance(AWSService, container.resolve(AWSService))
container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService)) container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService))
container.register_instance(AuthenticationService, container.resolve(AuthenticationService)) container.register_instance(AuthenticationService, container.resolve(AuthenticationService))
container.register_instance(AgentSignalsService, container.resolve(AgentSignalsService))
def _dirty_hacks(container: DIContainer): def _dirty_hacks(container: DIContainer):

View File

@ -15,6 +15,7 @@ from monkey_island.cc.repository import (
INodeRepository, INodeRepository,
ISimulationRepository, ISimulationRepository,
) )
from monkey_island.cc.services import AgentSignalsService
from monkey_island.cc.services.database import Database from monkey_island.cc.services.database import Database
@ -25,6 +26,7 @@ def setup_island_event_handlers(container: DIContainer):
_subscribe_reset_agent_configuration_events(island_event_queue, container) _subscribe_reset_agent_configuration_events(island_event_queue, container)
_subscribe_clear_simulation_data_events(island_event_queue, container) _subscribe_clear_simulation_data_events(island_event_queue, container)
_subscribe_set_island_mode_events(island_event_queue, container) _subscribe_set_island_mode_events(island_event_queue, container)
_subscribe_terminate_agents_events(island_event_queue, container)
def _subscribe_agent_registration_events( def _subscribe_agent_registration_events(
@ -74,3 +76,13 @@ def _subscribe_set_island_mode_events(
simulation_repository = container.resolve(ISimulationRepository) simulation_repository = container.resolve(ISimulationRepository)
island_event_queue.subscribe(topic, simulation_repository.set_mode) island_event_queue.subscribe(topic, simulation_repository.set_mode)
def _subscribe_terminate_agents_events(
island_event_queue: IIslandEventQueue, container: DIContainer
):
topic = IslandEventTopic.TERMINATE_AGENTS
agent_signals_service = container.resolve(AgentSignalsService)
island_event_queue.subscribe(topic, agent_signals_service.on_terminate_agents_signal)

View File

@ -84,12 +84,12 @@ class MapPageComponent extends AuthComponent {
} }
killAllMonkeys = () => { killAllMonkeys = () => {
this.authFetch('/api/monkey-control/stop-all-agents', this.authFetch('/api/agent-signals/terminate-all-agents',
{ {
method: 'POST', method: 'POST',
headers: {'Content-Type': 'application/json'}, headers: {'Content-Type': 'application/json'},
// Python uses floating point seconds, Date.now uses milliseconds, so convert // Python uses floating point seconds, Date.now uses milliseconds, so convert
body: JSON.stringify({kill_time: Date.now() / 1000.0}) body: JSON.stringify({terminate_time: Date.now() / 1000.0})
}) })
.then(res => res.json()) .then(res => res.json())
.then(() => {this.setState({killPressed: true})}); .then(() => {this.setState({killPressed: true})});

View File

@ -4,7 +4,7 @@ import pytest
import requests import requests
import requests_mock import requests_mock
from common import OperatingSystem from common import AgentSignals, OperatingSystem
from common.agent_event_serializers import ( from common.agent_event_serializers import (
AgentEventSerializerRegistry, AgentEventSerializerRegistry,
PydanticAgentEventSerializer, PydanticAgentEventSerializer,
@ -33,6 +33,8 @@ AGENT_REGISTRATION = AgentRegistrationData(
network_interfaces=[], network_interfaces=[],
) )
TIMESTAMP = 123456789
ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_URI = f"https://{SERVER}/api?action=is-up"
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
@ -42,6 +44,7 @@ ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents"
ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}" ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}"
ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration" ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration"
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials" ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials"
ISLAND_GET_AGENT_SIGNALS = f"https://{SERVER}/api/agent-signals/{AGENT_ID}"
class Event1(AbstractAgentEvent): class Event1(AbstractAgentEvent):
@ -325,52 +328,6 @@ def test_island_api_client_register_agent__status_code(
island_api_client.register_agent(AGENT_REGISTRATION) island_api_client.register_agent(AGENT_REGISTRATION)
@pytest.mark.parametrize(
"actual_error, expected_error",
[
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
(TimeoutError, IslandAPITimeoutError),
],
)
def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_AGENT_STOP_URI, exc=actual_error)
island_api_client.should_agent_stop(AGENT_ID)
@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_should_agent_stop__status_code(
island_api_client, status_code, expected_error
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_AGENT_STOP_URI, status_code=status_code)
island_api_client.should_agent_stop(AGENT_ID)
def test_island_api_client_should_agent_stop__bad_json(island_api_client):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(IslandAPIRequestFailedError):
m.get(ISLAND_AGENT_STOP_URI, content=b"bad")
island_api_client.should_agent_stop(AGENT_ID)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"actual_error, expected_error", "actual_error, expected_error",
[ [
@ -461,3 +418,62 @@ def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_
with pytest.raises(IslandAPIRequestFailedError): with pytest.raises(IslandAPIRequestFailedError):
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad") m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad")
island_api_client.get_credentials_for_propagation() island_api_client.get_credentials_for_propagation()
@pytest.mark.parametrize(
"actual_error, expected_error",
[
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
(TimeoutError, IslandAPITimeoutError),
],
)
def test_island_api_client__get_agent_signals(island_api_client, actual_error, expected_error):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_GET_AGENT_SIGNALS, exc=actual_error)
island_api_client.get_agent_signals(agent_id=AGENT_ID)
@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_get_agent_signals__status_code(
island_api_client, status_code, expected_error
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_GET_AGENT_SIGNALS, status_code=status_code)
island_api_client.get_agent_signals(agent_id=AGENT_ID)
@pytest.mark.parametrize("timestamp", [TIMESTAMP, None])
def test_island_api_client_get_agent_signals(island_api_client, timestamp):
expected_agent_signals = AgentSignals(terminate=timestamp)
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp})
actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID)
assert actual_agent_signals == expected_agent_signals
def test_island_api_client_get_agent_signals__bad_json(island_api_client):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(IslandAPIError):
m.get(ISLAND_GET_AGENT_SIGNALS, json={"bogus": "vogus"})
island_api_client.get_agent_signals(agent_id=AGENT_ID)

View File

@ -1,8 +1,10 @@
from typing import Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.i_control_channel import IslandCommunicationError from common import AgentSignals
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
IIslandAPIClient, IIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
@ -33,16 +35,24 @@ def control_channel(island_api_client) -> ControlChannel:
return ControlChannel(SERVER, AGENT_ID, island_api_client) return ControlChannel(SERVER, AGENT_ID, island_api_client)
def test_control_channel__should_agent_stop(control_channel, island_api_client): @pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
control_channel.should_agent_stop() def test_control_channel__should_agent_stop(
assert island_api_client.should_agent_stop.called_once() control_channel: IControlChannel,
island_api_client: IIslandAPIClient,
signal_time: Optional[int],
expected_should_stop: bool,
):
island_api_client.get_agent_signals = MagicMock(
return_value=AgentSignals(terminate=signal_time)
)
assert control_channel.should_agent_stop() is expected_should_stop
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) @pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
def test_control_channel__should_agent_stop_raises_error( def test_control_channel__should_agent_stop_raises_error(
control_channel, island_api_client, api_error control_channel, island_api_client, api_error
): ):
island_api_client.should_agent_stop.side_effect = api_error() island_api_client.get_agent_signals.side_effect = api_error()
with pytest.raises(IslandCommunicationError): with pytest.raises(IslandCommunicationError):
control_channel.should_agent_stop() control_channel.should_agent_stop()

View File

@ -3,7 +3,7 @@ import uuid
import pytest import pytest
from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError, ParentNotFoundError from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError
from monkey_island.cc.models.monkey_ttl import MonkeyTtl from monkey_island.cc.models.monkey_ttl import MonkeyTtl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -162,35 +162,3 @@ class TestMonkey:
cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info() cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info()
assert cache_info_after_query.hits == 2 assert cache_info_after_query.hits == 2
@pytest.mark.usefixtures("uses_database")
def test_has_parent(self):
monkey_1 = Monkey(guid=str(uuid.uuid4()))
monkey_2 = Monkey(guid=str(uuid.uuid4()))
monkey_1.parent = [[monkey_2.guid]]
monkey_1.save()
assert monkey_1.has_parent()
@pytest.mark.usefixtures("uses_database")
def test_has_no_parent(self):
monkey_1 = Monkey(guid=str(uuid.uuid4()))
monkey_1.parent = [[monkey_1.guid]]
monkey_1.save()
assert not monkey_1.has_parent()
@pytest.mark.usefixtures("uses_database")
def test_get_parent(self):
monkey_1 = Monkey(guid=str(uuid.uuid4()))
monkey_2 = Monkey(guid=str(uuid.uuid4()))
monkey_1.parent = [[monkey_2.guid]]
monkey_1.save()
monkey_2.save()
assert monkey_1.get_parent().guid == monkey_2.guid
@pytest.mark.usefixtures("uses_database")
def test_get_parent_no_parent(self):
monkey_1 = Monkey(guid=str(uuid.uuid4()))
monkey_1.parent = [[monkey_1.guid]]
monkey_1.save()
with pytest.raises(ParentNotFoundError):
monkey_1.get_parent()

View File

@ -17,14 +17,29 @@ from monkey_island.cc.repository import (
) )
VICTIM_ZERO_ID = uuid4() VICTIM_ZERO_ID = uuid4()
VICTIM_TWO_ID = uuid4()
VICTIM_THREE_ID = uuid4()
PROGENITOR_AGENT = Agent(
id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)
)
DESCENDANT_AGENT = Agent(
id=VICTIM_THREE_ID,
machine_id=4,
start_time=datetime.fromtimestamp(1661856868),
parent_id=VICTIM_TWO_ID,
)
RUNNING_AGENTS = ( RUNNING_AGENTS = (
Agent(id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)), PROGENITOR_AGENT,
Agent( Agent(
id=uuid4(), id=VICTIM_TWO_ID,
machine_id=2, machine_id=2,
start_time=datetime.fromtimestamp(1661856818), start_time=datetime.fromtimestamp(1661856818),
parent_id=VICTIM_ZERO_ID, parent_id=VICTIM_ZERO_ID,
), ),
DESCENDANT_AGENT,
) )
STOPPED_AGENTS = ( STOPPED_AGENTS = (
Agent( Agent(
@ -172,6 +187,24 @@ def test_get_running_agents__retrieval_error(error_raising_agent_repository):
error_raising_agent_repository.get_running_agents() error_raising_agent_repository.get_running_agents()
@pytest.mark.parametrize("agent", [DESCENDANT_AGENT, PROGENITOR_AGENT])
def test_get_progenitor(agent_repository, agent):
actual_progenitor = agent_repository.get_progenitor(agent)
assert actual_progenitor == PROGENITOR_AGENT
def test_get_progenitor__id_not_found(agent_repository):
dummy_agent = Agent(id=uuid4(), machine_id=10, start_time=datetime.now(), parent_id=uuid4())
with pytest.raises(UnknownRecordError):
agent_repository.get_progenitor(dummy_agent)
def test_get_progenitor__retrieval_error(error_raising_agent_repository):
with pytest.raises(RetrievalError):
error_raising_agent_repository.get_progenitor(AGENTS[1])
def test_reset(agent_repository): def test_reset(agent_repository):
# Ensure the repository is not empty # Ensure the repository is not empty
for agent in AGENTS: for agent in AGENTS:

View File

@ -0,0 +1,62 @@
from http import HTTPStatus
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from tests.common import StubDIContainer
from common.agent_signals import AgentSignals as Signals
from monkey_island.cc.repository import RetrievalError, StorageError
from monkey_island.cc.services import AgentSignalsService
TIMESTAMP_1 = 123456789
TIMESTAMP_2 = 123546789
SIGNALS_1 = Signals(terminate=TIMESTAMP_1)
SIGNALS_2 = Signals(terminate=TIMESTAMP_2)
AGENT_ID_1 = UUID("c0dd10b3-e21a-4da9-9d96-a99c19ebd7c5")
AGENT_ID_2 = UUID("9b4279f6-6ec5-4953-821e-893ddc71a988")
@pytest.fixture
def mock_agent_signals_service():
return MagicMock(spec=AgentSignalsService)
@pytest.fixture
def flask_client(build_flask_client, mock_agent_signals_service):
container = StubDIContainer()
container.register_instance(AgentSignalsService, mock_agent_signals_service)
with build_flask_client(container) as flask_client:
yield flask_client
@pytest.mark.parametrize(
"url, signals",
[
(f"/api/agent-signals/{AGENT_ID_1}", SIGNALS_1),
(f"/api/agent-signals/{AGENT_ID_2}", SIGNALS_2),
],
)
def test_agent_signals_get(flask_client, mock_agent_signals_service, url, signals):
mock_agent_signals_service.get_signals.return_value = signals
resp = flask_client.get(url, follow_redirects=True)
assert resp.status_code == HTTPStatus.OK
assert resp.json == signals.dict(simplify=True)
@pytest.mark.parametrize(
"url, error",
[
(f"/api/agent-signals/{AGENT_ID_1}", RetrievalError),
(f"/api/agent-signals/{AGENT_ID_2}", StorageError),
],
)
def test_agent_signals_get__internal_server_error(
flask_client, mock_agent_signals_service, url, error
):
mock_agent_signals_service.get_signals.side_effect = error
resp = flask_client.get(url, follow_redirects=True)
assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR

View File

@ -0,0 +1,51 @@
from http import HTTPStatus
from unittest.mock import MagicMock
import pytest
from tests.common import StubDIContainer
from monkey_island.cc.event_queue import IIslandEventQueue
from monkey_island.cc.resources import TerminateAllAgents
TIMESTAMP = 123456789
@pytest.fixture
def flask_client(build_flask_client):
container = StubDIContainer()
mock_island_event_queue = MagicMock(spec=IIslandEventQueue)
mock_island_event_queue.publish.side_effect = None
container.register_instance(IIslandEventQueue, mock_island_event_queue)
with build_flask_client(container) as flask_client:
yield flask_client
def test_terminate_all_agents_post(flask_client):
resp = flask_client.post(
TerminateAllAgents.urls[0],
json={"terminate_time": TIMESTAMP},
follow_redirects=True,
)
assert resp.status_code == HTTPStatus.NO_CONTENT
@pytest.mark.parametrize(
"bad_data",
[
"bad timestamp",
{},
{"wrong_key": TIMESTAMP},
TIMESTAMP,
{"terminate_time": 0},
{"terminate_time": -1},
],
)
def test_terminate_all_agents_post__invalid_timestamp(flask_client, bad_data):
resp = flask_client.post(
TerminateAllAgents.urls[0],
json=bad_data,
follow_redirects=True,
)
assert resp.status_code == HTTPStatus.BAD_REQUEST

View File

@ -0,0 +1,144 @@
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from common.types import AgentID
from monkey_island.cc.models import Agent, IslandMode, Simulation
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError
from monkey_island.cc.services import AgentSignalsService
AGENT_1 = Agent(
id=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"),
machine_id=1,
start_time=100,
parent_id=None,
)
AGENT_2 = Agent(
id=UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10"),
machine_id=2,
start_time=200,
parent_id=AGENT_1.id,
)
AGENT_3 = Agent(
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"),
machine_id=3,
start_time=300,
parent_id=AGENT_2.id,
)
AGENTS = [AGENT_1, AGENT_2, AGENT_3]
@pytest.fixture
def mock_simulation_repository() -> IAgentRepository:
return MagicMock(spec=ISimulationRepository)
@pytest.fixture(scope="session")
def mock_agent_repository() -> IAgentRepository:
def get_agent_by_id(agent_id: AgentID) -> Agent:
for agent in AGENTS:
if agent.id == agent_id:
return agent
raise UnknownRecordError(str(agent_id))
agent_repository = MagicMock(spec=IAgentRepository)
agent_repository.get_progenitor = MagicMock(return_value=AGENT_1)
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
return agent_repository
@pytest.fixture
def agent_signals_service(mock_simulation_repository, mock_agent_repository) -> AgentSignalsService:
return AgentSignalsService(mock_simulation_repository, mock_agent_repository)
@pytest.mark.parametrize("agent", AGENTS)
def test_terminate_is_none(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=None)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate is None
@pytest.mark.parametrize("agent", AGENTS)
def test_agent_started_before_terminate(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
TERMINATE_TIMESTAMP = 400
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
@pytest.mark.parametrize("agent", AGENTS)
def test_agent_started_after_terminate(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
TERMINATE_TIMESTAMP = 50
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate is None
@pytest.mark.parametrize("agent", AGENTS)
def test_progenitor_started_before_terminate(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
TERMINATE_TIMESTAMP = 150
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
def test_on_terminate_agents_signal__stores_timestamp(
agent_signals_service: AgentSignalsService, mock_simulation_repository: ISimulationRepository
):
timestamp = 100
mock_simulation_repository.get_simulation = MagicMock(return_value=Simulation())
agent_signals_service.on_terminate_agents_signal(timestamp)
expected_value = Simulation(terminate_signal_time=timestamp)
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)
def test_on_terminate_agents_signal__updates_timestamp(
agent_signals_service: AgentSignalsService, mock_simulation_repository: ISimulationRepository
):
timestamp = 100
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=50)
)
agent_signals_service.on_terminate_agents_signal(timestamp)
expected_value = Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=timestamp)
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)

View File

@ -1,26 +1,6 @@
import uuid import uuid
import pytest
from monkey_island.cc.models import Config, Monkey from monkey_island.cc.models import Config, Monkey
from monkey_island.cc.models.agent_controls import AgentControls
from monkey_island.cc.services.infection_lifecycle import should_agent_die
@pytest.mark.usefixtures("uses_database")
def test_should_agent_die_by_config(monkeypatch):
monkey = Monkey(guid=str(uuid.uuid4()))
monkey.config = Config()
monkey.should_stop = True
monkey.save()
assert should_agent_die(monkey.guid)
monkeypatch.setattr(
"monkey_island.cc.services.infection_lifecycle._is_monkey_killed_manually", lambda _: False
)
monkey.should_stop = True
monkey.save()
assert not should_agent_die(monkey.guid)
def create_monkey(launch_time): def create_monkey(launch_time):
@ -32,80 +12,9 @@ def create_monkey(launch_time):
return monkey return monkey
@pytest.mark.usefixtures("uses_database")
def test_should_agent_die_no_kill_event():
monkey = create_monkey(launch_time=3)
kill_event = AgentControls()
kill_event.save()
assert not should_agent_die(monkey.guid)
def create_kill_event(event_time):
kill_event = AgentControls(last_stop_all=event_time)
kill_event.save()
return kill_event
def create_parent(child_monkey, launch_time): def create_parent(child_monkey, launch_time):
monkey_parent = Monkey(guid=str(uuid.uuid4())) monkey_parent = Monkey(guid=str(uuid.uuid4()))
child_monkey.parent = [[monkey_parent.guid]] child_monkey.parent = [[monkey_parent.guid]]
monkey_parent.launch_time = launch_time monkey_parent.launch_time = launch_time
monkey_parent.save() monkey_parent.save()
child_monkey.save() child_monkey.save()
@pytest.mark.usefixtures("uses_database")
def test_was_agent_killed_manually(monkeypatch):
monkey = create_monkey(launch_time=2)
create_kill_event(event_time=3)
assert should_agent_die(monkey.guid)
@pytest.mark.usefixtures("uses_database")
def test_agent_killed_on_wakeup(monkeypatch):
monkey = create_monkey(launch_time=2)
create_kill_event(event_time=2)
assert should_agent_die(monkey.guid)
@pytest.mark.usefixtures("uses_database")
def test_manual_kill_dont_affect_new_monkeys(monkeypatch):
monkey = create_monkey(launch_time=3)
create_kill_event(event_time=2)
assert not should_agent_die(monkey.guid)
@pytest.mark.usefixtures("uses_database")
def test_parent_manually_killed(monkeypatch):
monkey = create_monkey(launch_time=3)
create_parent(child_monkey=monkey, launch_time=1)
create_kill_event(event_time=2)
assert should_agent_die(monkey.guid)
@pytest.mark.usefixtures("uses_database")
def test_parent_manually_killed_on_wakeup(monkeypatch):
monkey = create_monkey(launch_time=3)
create_parent(child_monkey=monkey, launch_time=2)
create_kill_event(event_time=2)
assert should_agent_die(monkey.guid)
@pytest.mark.usefixtures("uses_database")
def test_manual_kill_dont_affect_new_monkeys_with_parent(monkeypatch):
monkey = create_monkey(launch_time=3)
create_parent(child_monkey=monkey, launch_time=2)
create_kill_event(event_time=1)
assert not should_agent_die(monkey.guid)

View File

@ -27,6 +27,7 @@ from monkey_island.cc.repository.i_simulation_repository import ISimulationRepos
from monkey_island.cc.repository.ICredentials import ICredentialsRepository from monkey_island.cc.repository.ICredentials import ICredentialsRepository
from monkey_island.cc.repository.zero_trust.IEventRepository import IEventRepository from monkey_island.cc.repository.zero_trust.IEventRepository import IEventRepository
from monkey_island.cc.repository.zero_trust.IFindingRepository import IFindingRepository from monkey_island.cc.repository.zero_trust.IFindingRepository import IFindingRepository
from monkey_island.cc.services import AgentSignalsService
fake_monkey_dir_path # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37) fake_monkey_dir_path # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37)
set_os_linux # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37) set_os_linux # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37)