forked from p34709852/monkey
Merge pull request #2349 from guardicore/2261-refactor-manual-agent-logic
2261 refactor manual agent logic
This commit is contained in:
commit
dbaa56c39d
|
@ -24,6 +24,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||||
- The ability to customize the file extension used by ransomware when
|
- The ability to customize the file extension used by ransomware when
|
||||||
encrypting files. #1242
|
encrypting files. #1242
|
||||||
- `/api/agents` endpoint.
|
- `/api/agents` endpoint.
|
||||||
|
- `/api/agent-signals` endpoint. #2261
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Reset workflow. Now it's possible to delete data gathered by agents without
|
- Reset workflow. Now it's possible to delete data gathered by agents without
|
||||||
|
@ -64,6 +65,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||||
- Tunneling to relays to provide better firewall evasion, faster Island
|
- Tunneling to relays to provide better firewall evasion, faster Island
|
||||||
connection times, unlimited hops, and a more resilient way for agents to call
|
connection times, unlimited hops, and a more resilient way for agents to call
|
||||||
home. #2216, #1583
|
home. #2216, #1583
|
||||||
|
- "/api/monkey-control/stop-all-agents" to "/api/agent-signals/terminate-all-agents" #2261
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
- VSFTPD exploiter. #1533
|
- VSFTPD exploiter. #1533
|
||||||
|
@ -109,6 +111,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||||
- "/api/configuration/export" endpoint. #2002
|
- "/api/configuration/export" endpoint. #2002
|
||||||
- "/api/island-configuration" endpoint. #2003
|
- "/api/island-configuration" endpoint. #2003
|
||||||
- "-t/--tunnel" from agent command line arguments. #2216
|
- "-t/--tunnel" from agent command line arguments. #2216
|
||||||
|
- "/api/monkey-control/neets-to-stop". #2261
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- A bug in network map page that caused delay of telemetry log loading. #1545
|
- A bug in network map page that caused delay of telemetry log loading. #1545
|
||||||
|
|
|
@ -88,8 +88,9 @@ class MonkeyIslandClient(object):
|
||||||
|
|
||||||
@avoid_race_condition
|
@avoid_race_condition
|
||||||
def kill_all_monkeys(self):
|
def kill_all_monkeys(self):
|
||||||
|
# TODO change this request, because monkey-control resource got removed
|
||||||
response = self.requests.post_json(
|
response = self.requests.post_json(
|
||||||
"api/monkey-control/stop-all-agents", json={"kill_time": time.time()}
|
"api/agent-signals/terminate-all-agents", json={"terminate_time": time.time()}
|
||||||
)
|
)
|
||||||
if response.ok:
|
if response.ok:
|
||||||
LOGGER.info("Killing all monkeys after the test.")
|
LOGGER.info("Killing all monkeys after the test.")
|
||||||
|
|
|
@ -7,3 +7,4 @@ from .operating_system import OperatingSystem
|
||||||
from . import types
|
from . import types
|
||||||
from . import base_models
|
from . import base_models
|
||||||
from .agent_registration_data import AgentRegistrationData
|
from .agent_registration_data import AgentRegistrationData
|
||||||
|
from .agent_signals import AgentSignals
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base_models import InfectionMonkeyBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSignals(InfectionMonkeyBaseModel):
|
||||||
|
terminate: Optional[datetime]
|
|
@ -6,7 +6,7 @@ from typing import List, Sequence
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common import AgentRegistrationData, OperatingSystem
|
from common import AgentRegistrationData, AgentSignals, OperatingSystem
|
||||||
from common.agent_configuration import AgentConfiguration
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
@ -146,19 +146,6 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@handle_island_errors
|
|
||||||
@convert_json_error_to_island_api_error
|
|
||||||
def should_agent_stop(self, agent_id: str) -> bool:
|
|
||||||
url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}"
|
|
||||||
response = requests.get( # noqa: DUO123
|
|
||||||
url,
|
|
||||||
verify=False,
|
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return response.json()["stop_agent"]
|
|
||||||
|
|
||||||
@handle_island_errors
|
@handle_island_errors
|
||||||
@convert_json_error_to_island_api_error
|
@convert_json_error_to_island_api_error
|
||||||
def get_config(self) -> AgentConfiguration:
|
def get_config(self) -> AgentConfiguration:
|
||||||
|
@ -199,6 +186,18 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
|
|
||||||
return serialized_events
|
return serialized_events
|
||||||
|
|
||||||
|
@handle_island_errors
|
||||||
|
@convert_json_error_to_island_api_error
|
||||||
|
def get_agent_signals(self, agent_id: str) -> AgentSignals:
|
||||||
|
url = f"{self._api_url}/agent-signals/{agent_id}"
|
||||||
|
response = requests.get( # noqa: DUO123
|
||||||
|
url,
|
||||||
|
verify=False,
|
||||||
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return AgentSignals(**response.json())
|
||||||
|
|
||||||
|
|
||||||
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
|
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from common import AgentRegistrationData, OperatingSystem
|
from common import AgentRegistrationData, AgentSignals, OperatingSystem
|
||||||
from common.agent_configuration import AgentConfiguration
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
|
@ -107,19 +107,6 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPITimeoutError: If the command timed out
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def should_agent_stop(self, agent_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check with the island to see if the agent should stop
|
|
||||||
|
|
||||||
:param agent_id: The agent identifier for the agent to check
|
|
||||||
:raises IslandAPIConnectionError: If the client could not connect to the island
|
|
||||||
:raises IslandAPIRequestError: If there was a problem with the client request
|
|
||||||
:raises IslandAPIRequestFailedError: If the server experienced an error
|
|
||||||
:raises IslandAPITimeoutError: If the command timed out
|
|
||||||
:return: True if the agent should stop, otherwise False
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_config(self) -> AgentConfiguration:
|
def get_config(self) -> AgentConfiguration:
|
||||||
"""
|
"""
|
||||||
|
@ -143,3 +130,16 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPITimeoutError: If the command timed out
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
:return: Credentials
|
:return: Credentials
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_agent_signals(self, agent_id: str) -> AgentSignals:
|
||||||
|
"""
|
||||||
|
Gets an agent's signals from the island
|
||||||
|
|
||||||
|
:param agent_id: ID of the agent whose signals should be retrieved
|
||||||
|
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||||
|
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||||
|
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||||
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
|
:return: The relevant agent's signals
|
||||||
|
"""
|
||||||
|
|
|
@ -36,7 +36,8 @@ class ControlChannel(IControlChannel):
|
||||||
if not self._control_channel_server:
|
if not self._control_channel_server:
|
||||||
logger.error("Agent should stop because it can't connect to the C&C server.")
|
logger.error("Agent should stop because it can't connect to the C&C server.")
|
||||||
return True
|
return True
|
||||||
return self._island_api_client.should_agent_stop(self._agent_id)
|
agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
|
||||||
|
return agent_signals.terminate is not None
|
||||||
|
|
||||||
@handle_island_api_errors
|
@handle_island_api_errors
|
||||||
def get_config(self) -> AgentConfiguration:
|
def get_config(self) -> AgentConfiguration:
|
||||||
|
|
|
@ -123,7 +123,7 @@ class InfectionMonkey:
|
||||||
self._control_client = ControlClient(
|
self._control_client = ControlClient(
|
||||||
server_address=server, island_api_client=self._island_api_client
|
server_address=server, island_api_client=self._island_api_client
|
||||||
)
|
)
|
||||||
self._control_channel = ControlChannel(server, GUID, self._island_api_client)
|
self._control_channel = ControlChannel(server, get_agent_id(), self._island_api_client)
|
||||||
self._register_agent(server)
|
self._register_agent(server)
|
||||||
|
|
||||||
# TODO Refactor the telemetry messengers to accept control client
|
# TODO Refactor the telemetry messengers to accept control client
|
||||||
|
|
|
@ -15,6 +15,7 @@ from monkey_island.cc.resources import (
|
||||||
AgentConfiguration,
|
AgentConfiguration,
|
||||||
AgentEvents,
|
AgentEvents,
|
||||||
Agents,
|
Agents,
|
||||||
|
AgentSignals,
|
||||||
ClearSimulationData,
|
ClearSimulationData,
|
||||||
IPAddresses,
|
IPAddresses,
|
||||||
IslandLog,
|
IslandLog,
|
||||||
|
@ -23,9 +24,9 @@ from monkey_island.cc.resources import (
|
||||||
PropagationCredentials,
|
PropagationCredentials,
|
||||||
RemoteRun,
|
RemoteRun,
|
||||||
ResetAgentConfiguration,
|
ResetAgentConfiguration,
|
||||||
|
TerminateAllAgents,
|
||||||
)
|
)
|
||||||
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
||||||
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
|
|
||||||
from monkey_island.cc.resources.attack.attack_report import AttackReport
|
from monkey_island.cc.resources.attack.attack_report import AttackReport
|
||||||
from monkey_island.cc.resources.auth import Authenticate, Register, RegistrationStatus, init_jwt
|
from monkey_island.cc.resources.auth import Authenticate, Register, RegistrationStatus, init_jwt
|
||||||
from monkey_island.cc.resources.blackbox.log_blackbox_endpoint import LogBlackboxEndpoint
|
from monkey_island.cc.resources.blackbox.log_blackbox_endpoint import LogBlackboxEndpoint
|
||||||
|
@ -188,6 +189,7 @@ def init_restful_endpoints(api: FlaskDIWrapper):
|
||||||
api.add_resource(IPAddresses)
|
api.add_resource(IPAddresses)
|
||||||
|
|
||||||
api.add_resource(AgentEvents)
|
api.add_resource(AgentEvents)
|
||||||
|
api.add_resource(AgentSignals)
|
||||||
|
|
||||||
# API Spec: These two should be the same resource, GET for download and POST for upload
|
# API Spec: These two should be the same resource, GET for download and POST for upload
|
||||||
api.add_resource(PBAFileDownload)
|
api.add_resource(PBAFileDownload)
|
||||||
|
@ -196,8 +198,6 @@ def init_restful_endpoints(api: FlaskDIWrapper):
|
||||||
api.add_resource(PropagationCredentials)
|
api.add_resource(PropagationCredentials)
|
||||||
api.add_resource(RemoteRun)
|
api.add_resource(RemoteRun)
|
||||||
api.add_resource(Version)
|
api.add_resource(Version)
|
||||||
api.add_resource(StopAgentCheck)
|
|
||||||
api.add_resource(StopAllAgents)
|
|
||||||
|
|
||||||
# Resources used by black box tests
|
# Resources used by black box tests
|
||||||
# API Spec: Fix all the following endpoints, see comments in the resource classes
|
# API Spec: Fix all the following endpoints, see comments in the resource classes
|
||||||
|
@ -211,6 +211,7 @@ def init_restful_endpoints(api: FlaskDIWrapper):
|
||||||
def init_rpc_endpoints(api: FlaskDIWrapper):
|
def init_rpc_endpoints(api: FlaskDIWrapper):
|
||||||
api.add_resource(ResetAgentConfiguration)
|
api.add_resource(ResetAgentConfiguration)
|
||||||
api.add_resource(ClearSimulationData)
|
api.add_resource(ClearSimulationData)
|
||||||
|
api.add_resource(TerminateAllAgents)
|
||||||
|
|
||||||
|
|
||||||
def init_app(mongo_url: str, container: DIContainer):
|
def init_app(mongo_url: str, container: DIContainer):
|
||||||
|
|
|
@ -9,6 +9,7 @@ class IslandEventTopic(Enum):
|
||||||
CLEAR_SIMULATION_DATA = auto()
|
CLEAR_SIMULATION_DATA = auto()
|
||||||
RESET_AGENT_CONFIGURATION = auto()
|
RESET_AGENT_CONFIGURATION = auto()
|
||||||
SET_ISLAND_MODE = auto()
|
SET_ISLAND_MODE = auto()
|
||||||
|
TERMINATE_AGENTS = auto()
|
||||||
|
|
||||||
|
|
||||||
class IIslandEventQueue(ABC):
|
class IIslandEventQueue(ABC):
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
from .agent_controls import AgentControls
|
|
|
@ -1,8 +0,0 @@
|
||||||
from mongoengine import Document, FloatField
|
|
||||||
|
|
||||||
|
|
||||||
# TODO rename to Simulation, add other metadata
|
|
||||||
class AgentControls(Document):
|
|
||||||
|
|
||||||
# Timestamp of the last "kill all agents" command
|
|
||||||
last_stop_all = FloatField(default=None)
|
|
|
@ -22,10 +22,6 @@ from monkey_island.cc.models.monkey_ttl import MonkeyTtl, create_monkey_ttl_docu
|
||||||
from monkey_island.cc.server_utils.consts import DEFAULT_MONKEY_TTL_EXPIRY_DURATION_IN_SECONDS
|
from monkey_island.cc.server_utils.consts import DEFAULT_MONKEY_TTL_EXPIRY_DURATION_IN_SECONDS
|
||||||
|
|
||||||
|
|
||||||
class ParentNotFoundError(Exception):
|
|
||||||
"""Raise when trying to get a parent of monkey that doesn't have one"""
|
|
||||||
|
|
||||||
|
|
||||||
class Monkey(Document):
|
class Monkey(Document):
|
||||||
"""
|
"""
|
||||||
This class has 2 main section:
|
This class has 2 main section:
|
||||||
|
@ -98,18 +94,6 @@ class Monkey(Document):
|
||||||
monkey_is_dead = True
|
monkey_is_dead = True
|
||||||
return monkey_is_dead
|
return monkey_is_dead
|
||||||
|
|
||||||
def has_parent(self):
|
|
||||||
for p in self.parent:
|
|
||||||
if p[0] != self.guid:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_parent(self):
|
|
||||||
if self.has_parent():
|
|
||||||
return Monkey.objects(guid=self.parent[0][0]).first()
|
|
||||||
else:
|
|
||||||
raise ParentNotFoundError(f"No parent was found for agent with GUID {self.guid}")
|
|
||||||
|
|
||||||
def get_os(self):
|
def get_os(self):
|
||||||
os = "unknown"
|
os = "unknown"
|
||||||
if self.description.lower().find("linux") != -1:
|
if self.description.lower().find("linux") != -1:
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from common.base_models import InfectionMonkeyBaseModel
|
from common.base_models import InfectionMonkeyBaseModel
|
||||||
|
|
||||||
|
@ -13,3 +15,4 @@ class IslandMode(Enum):
|
||||||
|
|
||||||
class Simulation(InfectionMonkeyBaseModel):
|
class Simulation(InfectionMonkeyBaseModel):
|
||||||
mode: IslandMode = IslandMode.UNSET
|
mode: IslandMode = IslandMode.UNSET
|
||||||
|
terminate_signal_time: Optional[datetime] = None
|
||||||
|
|
|
@ -16,7 +16,7 @@ class IAgentRepository(ABC):
|
||||||
already exists, update it.
|
already exists, update it.
|
||||||
|
|
||||||
:param agent: The `agent` to be inserted or updated
|
:param agent: The `agent` to be inserted or updated
|
||||||
:raises StorageError: If an error occurred while attempting to store the `Agent`
|
:raises StorageError: If an error occurs while attempting to store the `Agent`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -28,7 +28,7 @@ class IAgentRepository(ABC):
|
||||||
:return: An `Agent` with a matching `id`
|
:return: An `Agent` with a matching `id`
|
||||||
:raises UnknownRecordError: If an `Agent` with the specified `id` does not exist in the
|
:raises UnknownRecordError: If an `Agent` with the specified `id` does not exist in the
|
||||||
repository
|
repository
|
||||||
:raises RetrievalError: If an error occurred while attempting to retrieve the `Agent`
|
:raises RetrievalError: If an error occurs while attempting to retrieve the `Agent`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -37,7 +37,18 @@ class IAgentRepository(ABC):
|
||||||
Get all `Agents` that are currently running
|
Get all `Agents` that are currently running
|
||||||
|
|
||||||
:return: All `Agents` that are currently running
|
:return: All `Agents` that are currently running
|
||||||
:raises RetrievalError: If an error occurred while attempting to retrieve the `Agents`
|
:raises RetrievalError: If an error occurs while attempting to retrieve the `Agents`
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_progenitor(self, agent: Agent) -> Agent:
|
||||||
|
"""
|
||||||
|
Gets the progenitor `Agent` for the agent.
|
||||||
|
|
||||||
|
:param agent: The Agent for which we want the progenitor
|
||||||
|
:return: `Agent` progenitor ( an initial agent that started the exploitation chain)
|
||||||
|
:raises RetrievalError: If an error occurrs while attempting to retrieve the `Agent`
|
||||||
|
:raises UnknownRecordError: If the agent ID is not in the repository
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -26,7 +26,7 @@ class IMachineRepository(ABC):
|
||||||
`Machine` already exists, update it.
|
`Machine` already exists, update it.
|
||||||
|
|
||||||
:param machine: The `Machine` to be inserted or updated
|
:param machine: The `Machine` to be inserted or updated
|
||||||
:raises StorageError: If an error occurred while attempting to store the `Machine`
|
:raises StorageError: If an error occurs while attempting to store the `Machine`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -38,7 +38,7 @@ class IMachineRepository(ABC):
|
||||||
:return: A `Machine` with a matching `id`
|
:return: A `Machine` with a matching `id`
|
||||||
:raises UnknownRecordError: If a `Machine` with the specified `id` does not exist in the
|
:raises UnknownRecordError: If a `Machine` with the specified `id` does not exist in the
|
||||||
repository
|
repository
|
||||||
:raises RetrievalError: If an error occurred while attempting to retrieve the `Machine`
|
:raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -50,7 +50,7 @@ class IMachineRepository(ABC):
|
||||||
:return: A `Machine` with a matching `hardware_id`
|
:return: A `Machine` with a matching `hardware_id`
|
||||||
:raises UnknownRecordError: If a `Machine` with the specified `hardware_id` does not exist
|
:raises UnknownRecordError: If a `Machine` with the specified `hardware_id` does not exist
|
||||||
in the repository
|
in the repository
|
||||||
:raises RetrievalError: If an error occurred while attempting to retrieve the `Machine`
|
:raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -62,7 +62,7 @@ class IMachineRepository(ABC):
|
||||||
:return: A sequence of Machines that have a network interface with a matching IP
|
:return: A sequence of Machines that have a network interface with a matching IP
|
||||||
:raises UnknownRecordError: If a `Machine` with the specified `ip` does not exist in the
|
:raises UnknownRecordError: If a `Machine` with the specified `ip` does not exist in the
|
||||||
repository
|
repository
|
||||||
:raises RetrievalError: If an error occurred while attempting to retrieve the `Machine`
|
:raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -70,6 +70,6 @@ class IMachineRepository(ABC):
|
||||||
"""
|
"""
|
||||||
Removes all data from the repository
|
Removes all data from the repository
|
||||||
|
|
||||||
:raises RemovalError: If an error occurred while attempting to remove all `Machines` from
|
:raises RemovalError: If an error occurs while attempting to remove all `Machines` from
|
||||||
the repository
|
the repository
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -22,7 +22,7 @@ class INodeRepository(ABC):
|
||||||
:param src: The machine that the connection or communication originated from
|
:param src: The machine that the connection or communication originated from
|
||||||
:param dst: The machine that the src communicated with
|
:param dst: The machine that the src communicated with
|
||||||
:param communication_type: The way the machines communicated
|
:param communication_type: The way the machines communicated
|
||||||
:raises StorageError: If an error occurred while attempting to upsert the Node
|
:raises StorageError: If an error occurs while attempting to upsert the Node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -31,7 +31,7 @@ class INodeRepository(ABC):
|
||||||
Return all nodes that are stored in the repository
|
Return all nodes that are stored in the repository
|
||||||
|
|
||||||
:return: All known Nodes
|
:return: All known Nodes
|
||||||
:raises RetrievalError: If an error occurred while attempting to retrieve the nodes
|
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -39,6 +39,6 @@ class INodeRepository(ABC):
|
||||||
"""
|
"""
|
||||||
Removes all data from the repository
|
Removes all data from the repository
|
||||||
|
|
||||||
:raises RemovalError: If an error occurred while attempting to remove all `Nodes` from the
|
:raises RemovalError: If an error occurs while attempting to remove all `Nodes` from the
|
||||||
repository
|
repository
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -58,6 +58,14 @@ class MongoAgentRepository(IAgentRepository):
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise RetrievalError(f"Error retrieving running agents: {err}")
|
raise RetrievalError(f"Error retrieving running agents: {err}")
|
||||||
|
|
||||||
|
def get_progenitor(self, agent: Agent) -> Agent:
|
||||||
|
if agent.parent_id is None:
|
||||||
|
return agent
|
||||||
|
|
||||||
|
parent = self.get_agent_by_id(agent.parent_id)
|
||||||
|
|
||||||
|
return self.get_progenitor(parent)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
try:
|
try:
|
||||||
self._agents_collection.drop()
|
self._agents_collection.drop()
|
||||||
|
|
|
@ -10,3 +10,4 @@ from .pba_file_upload import PBAFileUpload, LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
|
||||||
from .pba_file_download import PBAFileDownload
|
from .pba_file_download import PBAFileDownload
|
||||||
from .agent_events import AgentEvents
|
from .agent_events import AgentEvents
|
||||||
from .agents import Agents
|
from .agents import Agents
|
||||||
|
from .agent_signals import AgentSignals, TerminateAllAgents
|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
from .stop_all_agents import StopAllAgents
|
|
||||||
from .stop_agent_check import StopAgentCheck
|
|
|
@ -1,11 +0,0 @@
|
||||||
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
|
||||||
from monkey_island.cc.services.infection_lifecycle import should_agent_die
|
|
||||||
|
|
||||||
|
|
||||||
class StopAgentCheck(AbstractResource):
|
|
||||||
# API Spec: Rename to AgentStopStatus or something, endpoint for this could be
|
|
||||||
# "/api/agents/<GUID>/stop-status"
|
|
||||||
urls = ["/api/monkey-control/needs-to-stop/<int:monkey_guid>"]
|
|
||||||
|
|
||||||
def get(self, monkey_guid: int):
|
|
||||||
return {"stop_agent": should_agent_die(monkey_guid)}
|
|
|
@ -1,27 +0,0 @@
|
||||||
import json
|
|
||||||
|
|
||||||
from flask import make_response, request
|
|
||||||
|
|
||||||
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
|
||||||
from monkey_island.cc.resources.request_authentication import jwt_required
|
|
||||||
from monkey_island.cc.resources.utils.semaphores import agent_killing_mutex
|
|
||||||
from monkey_island.cc.services.infection_lifecycle import set_stop_all, should_agent_die
|
|
||||||
|
|
||||||
|
|
||||||
class StopAllAgents(AbstractResource):
|
|
||||||
# API Spec: This is an action and there's no "resource"; RPC-style endpoint?
|
|
||||||
urls = ["/api/monkey-control/stop-all-agents"]
|
|
||||||
|
|
||||||
@jwt_required
|
|
||||||
def post(self):
|
|
||||||
with agent_killing_mutex:
|
|
||||||
data = json.loads(request.data)
|
|
||||||
if data["kill_time"]:
|
|
||||||
set_stop_all(data["kill_time"])
|
|
||||||
return make_response({}, 200)
|
|
||||||
else:
|
|
||||||
return make_response({}, 400)
|
|
||||||
|
|
||||||
# API Spec: This is the exact same thing as what's in StopAgentCheck
|
|
||||||
def get(self, monkey_guid):
|
|
||||||
return {"stop_agent": should_agent_die(monkey_guid)}
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .agent_signals import AgentSignals
|
||||||
|
from .terminate_all_agents import TerminateAllAgents
|
|
@ -0,0 +1,21 @@
|
||||||
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
||||||
|
from monkey_island.cc.services import AgentSignalsService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSignals(AbstractResource):
|
||||||
|
urls = ["/api/agent-signals/<string:agent_id>"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
):
|
||||||
|
self._agent_signals_service = agent_signals_service
|
||||||
|
|
||||||
|
def get(self, agent_id: str):
|
||||||
|
agent_signals = self._agent_signals_service.get_signals(agent_id)
|
||||||
|
return agent_signals.dict(simplify=True), HTTPStatus.OK
|
|
@ -0,0 +1,39 @@
|
||||||
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
|
from json import JSONDecodeError
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
|
||||||
|
from monkey_island.cc.event_queue import IIslandEventQueue, IslandEventTopic
|
||||||
|
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
||||||
|
from monkey_island.cc.resources.request_authentication import jwt_required
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TerminateAllAgents(AbstractResource):
|
||||||
|
urls = ["/api/agent-signals/terminate-all-agents"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
island_event_queue: IIslandEventQueue,
|
||||||
|
):
|
||||||
|
self._island_event_queue = island_event_queue
|
||||||
|
|
||||||
|
@jwt_required
|
||||||
|
def post(self):
|
||||||
|
try:
|
||||||
|
terminate_timestamp = request.json["terminate_time"]
|
||||||
|
if terminate_timestamp is None:
|
||||||
|
raise ValueError("Terminate signal's timestamp is empty")
|
||||||
|
elif terminate_timestamp <= 0:
|
||||||
|
raise ValueError("Terminate signal's timestamp is not a positive integer")
|
||||||
|
|
||||||
|
self._island_event_queue.publish(
|
||||||
|
IslandEventTopic.TERMINATE_AGENTS, timestamp=terminate_timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
except (JSONDecodeError, TypeError, ValueError, KeyError) as err:
|
||||||
|
return {"error": err}, HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
return {}, HTTPStatus.NO_CONTENT
|
|
@ -1,3 +1,4 @@
|
||||||
|
from .agent_signals_service import AgentSignalsService
|
||||||
from .authentication_service import AuthenticationService
|
from .authentication_service import AuthenticationService
|
||||||
|
|
||||||
from .aws import AWSService
|
from .aws import AWSService
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from common.agent_signals import AgentSignals
|
||||||
|
from common.types import AgentID
|
||||||
|
from monkey_island.cc.models import Simulation
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSignalsService:
|
||||||
|
def __init__(
|
||||||
|
self, simulation_repository: ISimulationRepository, agent_repository: IAgentRepository
|
||||||
|
):
|
||||||
|
self._simulation_repository = simulation_repository
|
||||||
|
self._agent_repository = agent_repository
|
||||||
|
|
||||||
|
def get_signals(self, agent_id: AgentID) -> AgentSignals:
|
||||||
|
"""
|
||||||
|
Gets the signals sent to a particular agent
|
||||||
|
|
||||||
|
:param agent_id: The ID of the agent whose signals need to be retrieved
|
||||||
|
:return: Signals sent to the relevant agent
|
||||||
|
"""
|
||||||
|
terminate_timestamp = self._get_terminate_signal_timestamp(agent_id)
|
||||||
|
return AgentSignals(terminate=terminate_timestamp)
|
||||||
|
|
||||||
|
def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]:
|
||||||
|
simulation = self._simulation_repository.get_simulation()
|
||||||
|
terminate_all_signal_time = simulation.terminate_signal_time
|
||||||
|
if terminate_all_signal_time is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent = self._agent_repository.get_agent_by_id(agent_id)
|
||||||
|
if agent.start_time <= terminate_all_signal_time:
|
||||||
|
return terminate_all_signal_time
|
||||||
|
|
||||||
|
progenitor = self._agent_repository.get_progenitor(agent)
|
||||||
|
if progenitor.start_time <= terminate_all_signal_time:
|
||||||
|
return terminate_all_signal_time
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def on_terminate_agents_signal(self, timestamp: datetime):
|
||||||
|
"""
|
||||||
|
Updates the simulation repository with the terminate signal's timestamp
|
||||||
|
|
||||||
|
:param timestamp: Timestamp of the terminate signal
|
||||||
|
"""
|
||||||
|
simulation = self._simulation_repository.get_simulation()
|
||||||
|
updated_simulation = Simulation(mode=simulation.mode, terminate_signal_time=timestamp)
|
||||||
|
|
||||||
|
self._simulation_repository.save_simulation(updated_simulation)
|
|
@ -4,7 +4,6 @@ from flask import jsonify
|
||||||
|
|
||||||
from monkey_island.cc.database import mongo
|
from monkey_island.cc.database import mongo
|
||||||
from monkey_island.cc.models import Config
|
from monkey_island.cc.models import Config
|
||||||
from monkey_island.cc.models.agent_controls import AgentControls
|
|
||||||
from monkey_island.cc.models.attack.attack_mitigations import AttackMitigations
|
from monkey_island.cc.models.attack.attack_mitigations import AttackMitigations
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -24,7 +23,6 @@ class Database(object):
|
||||||
for x in mongo.db.collection_names()
|
for x in mongo.db.collection_names()
|
||||||
if Database._should_drop(x, reset_config)
|
if Database._should_drop(x, reset_config)
|
||||||
]
|
]
|
||||||
Database.init_agent_controls()
|
|
||||||
logger.info("DB was reset")
|
logger.info("DB was reset")
|
||||||
return jsonify(status="OK")
|
return jsonify(status="OK")
|
||||||
|
|
||||||
|
@ -44,10 +42,6 @@ class Database(object):
|
||||||
mongo.db[collection_name].drop()
|
mongo.db[collection_name].drop()
|
||||||
logger.info("Dropped collection {}".format(collection_name))
|
logger.info("Dropped collection {}".format(collection_name))
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def init_agent_controls():
|
|
||||||
AgentControls().save()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_mitigations_missing() -> bool:
|
def is_mitigations_missing() -> bool:
|
||||||
return bool(AttackMitigations.COLLECTION_NAME not in mongo.db.list_collection_names())
|
return bool(AttackMitigations.COLLECTION_NAME not in mongo.db.list_collection_names())
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from monkey_island.cc.models import Monkey
|
|
||||||
from monkey_island.cc.models.agent_controls import AgentControls
|
|
||||||
from monkey_island.cc.services.node import NodeService
|
from monkey_island.cc.services.node import NodeService
|
||||||
from monkey_island.cc.services.reporting.report import ReportService
|
from monkey_island.cc.services.reporting.report import ReportService
|
||||||
from monkey_island.cc.services.reporting.report_generation_synchronisation import (
|
from monkey_island.cc.services.reporting.report_generation_synchronisation import (
|
||||||
|
@ -12,41 +10,6 @@ from monkey_island.cc.services.reporting.report_generation_synchronisation impor
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def set_stop_all(time: float):
|
|
||||||
# This will use Agent and Simulation repositories
|
|
||||||
for monkey in Monkey.objects():
|
|
||||||
monkey.should_stop = True
|
|
||||||
monkey.save()
|
|
||||||
agent_controls = AgentControls.objects.first()
|
|
||||||
agent_controls.last_stop_all = time
|
|
||||||
agent_controls.save()
|
|
||||||
|
|
||||||
|
|
||||||
def should_agent_die(guid: int) -> bool:
|
|
||||||
monkey = Monkey.objects(guid=str(guid)).first()
|
|
||||||
return _should_agent_stop(monkey) or _is_monkey_killed_manually(monkey)
|
|
||||||
|
|
||||||
|
|
||||||
def _should_agent_stop(monkey: Monkey) -> bool:
|
|
||||||
if monkey.should_stop:
|
|
||||||
# Only stop the agent once, to allow further runs on that machine
|
|
||||||
monkey.should_stop = False
|
|
||||||
monkey.save()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_monkey_killed_manually(monkey: Monkey) -> bool:
|
|
||||||
kill_timestamp = AgentControls.objects.first().last_stop_all
|
|
||||||
if kill_timestamp is None:
|
|
||||||
return False
|
|
||||||
if monkey.has_parent():
|
|
||||||
launch_timestamp = monkey.get_parent().launch_time
|
|
||||||
else:
|
|
||||||
launch_timestamp = monkey.launch_time
|
|
||||||
return int(kill_timestamp) >= int(launch_timestamp)
|
|
||||||
|
|
||||||
|
|
||||||
def get_completed_steps():
|
def get_completed_steps():
|
||||||
is_any_exists = NodeService.is_any_monkey_exists()
|
is_any_exists = NodeService.is_any_monkey_exists()
|
||||||
infection_done = NodeService.is_monkey_finished_running()
|
infection_done = NodeService.is_monkey_finished_running()
|
||||||
|
|
|
@ -47,7 +47,7 @@ from monkey_island.cc.repository import (
|
||||||
)
|
)
|
||||||
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
|
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
|
||||||
from monkey_island.cc.server_utils.encryption import ILockableEncryptor, RepositoryEncryptor
|
from monkey_island.cc.server_utils.encryption import ILockableEncryptor, RepositoryEncryptor
|
||||||
from monkey_island.cc.services import AWSService
|
from monkey_island.cc.services import AgentSignalsService, AWSService
|
||||||
from monkey_island.cc.services.attack.technique_reports.T1003 import T1003, T1003GetReportData
|
from monkey_island.cc.services.attack.technique_reports.T1003 import T1003, T1003GetReportData
|
||||||
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
||||||
from monkey_island.cc.setup.mongo.mongo_setup import MONGO_URL
|
from monkey_island.cc.setup.mongo.mongo_setup import MONGO_URL
|
||||||
|
@ -176,6 +176,7 @@ def _register_services(container: DIContainer):
|
||||||
container.register_instance(AWSService, container.resolve(AWSService))
|
container.register_instance(AWSService, container.resolve(AWSService))
|
||||||
container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService))
|
container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService))
|
||||||
container.register_instance(AuthenticationService, container.resolve(AuthenticationService))
|
container.register_instance(AuthenticationService, container.resolve(AuthenticationService))
|
||||||
|
container.register_instance(AgentSignalsService, container.resolve(AgentSignalsService))
|
||||||
|
|
||||||
|
|
||||||
def _dirty_hacks(container: DIContainer):
|
def _dirty_hacks(container: DIContainer):
|
||||||
|
|
|
@ -15,6 +15,7 @@ from monkey_island.cc.repository import (
|
||||||
INodeRepository,
|
INodeRepository,
|
||||||
ISimulationRepository,
|
ISimulationRepository,
|
||||||
)
|
)
|
||||||
|
from monkey_island.cc.services import AgentSignalsService
|
||||||
from monkey_island.cc.services.database import Database
|
from monkey_island.cc.services.database import Database
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,6 +26,7 @@ def setup_island_event_handlers(container: DIContainer):
|
||||||
_subscribe_reset_agent_configuration_events(island_event_queue, container)
|
_subscribe_reset_agent_configuration_events(island_event_queue, container)
|
||||||
_subscribe_clear_simulation_data_events(island_event_queue, container)
|
_subscribe_clear_simulation_data_events(island_event_queue, container)
|
||||||
_subscribe_set_island_mode_events(island_event_queue, container)
|
_subscribe_set_island_mode_events(island_event_queue, container)
|
||||||
|
_subscribe_terminate_agents_events(island_event_queue, container)
|
||||||
|
|
||||||
|
|
||||||
def _subscribe_agent_registration_events(
|
def _subscribe_agent_registration_events(
|
||||||
|
@ -74,3 +76,13 @@ def _subscribe_set_island_mode_events(
|
||||||
|
|
||||||
simulation_repository = container.resolve(ISimulationRepository)
|
simulation_repository = container.resolve(ISimulationRepository)
|
||||||
island_event_queue.subscribe(topic, simulation_repository.set_mode)
|
island_event_queue.subscribe(topic, simulation_repository.set_mode)
|
||||||
|
|
||||||
|
|
||||||
|
def _subscribe_terminate_agents_events(
|
||||||
|
island_event_queue: IIslandEventQueue, container: DIContainer
|
||||||
|
):
|
||||||
|
topic = IslandEventTopic.TERMINATE_AGENTS
|
||||||
|
|
||||||
|
agent_signals_service = container.resolve(AgentSignalsService)
|
||||||
|
|
||||||
|
island_event_queue.subscribe(topic, agent_signals_service.on_terminate_agents_signal)
|
||||||
|
|
|
@ -84,12 +84,12 @@ class MapPageComponent extends AuthComponent {
|
||||||
}
|
}
|
||||||
|
|
||||||
killAllMonkeys = () => {
|
killAllMonkeys = () => {
|
||||||
this.authFetch('/api/monkey-control/stop-all-agents',
|
this.authFetch('/api/agent-signals/terminate-all-agents',
|
||||||
{
|
{
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {'Content-Type': 'application/json'},
|
headers: {'Content-Type': 'application/json'},
|
||||||
// Python uses floating point seconds, Date.now uses milliseconds, so convert
|
// Python uses floating point seconds, Date.now uses milliseconds, so convert
|
||||||
body: JSON.stringify({kill_time: Date.now() / 1000.0})
|
body: JSON.stringify({terminate_time: Date.now() / 1000.0})
|
||||||
})
|
})
|
||||||
.then(res => res.json())
|
.then(res => res.json())
|
||||||
.then(() => {this.setState({killPressed: true})});
|
.then(() => {this.setState({killPressed: true})});
|
||||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
import requests
|
import requests
|
||||||
import requests_mock
|
import requests_mock
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import AgentSignals, OperatingSystem
|
||||||
from common.agent_event_serializers import (
|
from common.agent_event_serializers import (
|
||||||
AgentEventSerializerRegistry,
|
AgentEventSerializerRegistry,
|
||||||
PydanticAgentEventSerializer,
|
PydanticAgentEventSerializer,
|
||||||
|
@ -33,6 +33,8 @@ AGENT_REGISTRATION = AgentRegistrationData(
|
||||||
network_interfaces=[],
|
network_interfaces=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TIMESTAMP = 123456789
|
||||||
|
|
||||||
ISLAND_URI = f"https://{SERVER}/api?action=is-up"
|
ISLAND_URI = f"https://{SERVER}/api?action=is-up"
|
||||||
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
||||||
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
||||||
|
@ -42,6 +44,7 @@ ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents"
|
||||||
ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}"
|
ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}"
|
||||||
ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration"
|
ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration"
|
||||||
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials"
|
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials"
|
||||||
|
ISLAND_GET_AGENT_SIGNALS = f"https://{SERVER}/api/agent-signals/{AGENT_ID}"
|
||||||
|
|
||||||
|
|
||||||
class Event1(AbstractAgentEvent):
|
class Event1(AbstractAgentEvent):
|
||||||
|
@ -325,52 +328,6 @@ def test_island_api_client_register_agent__status_code(
|
||||||
island_api_client.register_agent(AGENT_REGISTRATION)
|
island_api_client.register_agent(AGENT_REGISTRATION)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"actual_error, expected_error",
|
|
||||||
[
|
|
||||||
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
|
||||||
(TimeoutError, IslandAPITimeoutError),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error):
|
|
||||||
with requests_mock.Mocker() as m:
|
|
||||||
m.get(ISLAND_URI)
|
|
||||||
island_api_client.connect(SERVER)
|
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
|
||||||
m.get(ISLAND_AGENT_STOP_URI, exc=actual_error)
|
|
||||||
island_api_client.should_agent_stop(AGENT_ID)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"status_code, expected_error",
|
|
||||||
[
|
|
||||||
(401, IslandAPIRequestError),
|
|
||||||
(501, IslandAPIRequestFailedError),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_island_api_client_should_agent_stop__status_code(
|
|
||||||
island_api_client, status_code, expected_error
|
|
||||||
):
|
|
||||||
with requests_mock.Mocker() as m:
|
|
||||||
m.get(ISLAND_URI)
|
|
||||||
island_api_client.connect(SERVER)
|
|
||||||
|
|
||||||
with pytest.raises(expected_error):
|
|
||||||
m.get(ISLAND_AGENT_STOP_URI, status_code=status_code)
|
|
||||||
island_api_client.should_agent_stop(AGENT_ID)
|
|
||||||
|
|
||||||
|
|
||||||
def test_island_api_client_should_agent_stop__bad_json(island_api_client):
|
|
||||||
with requests_mock.Mocker() as m:
|
|
||||||
m.get(ISLAND_URI)
|
|
||||||
island_api_client.connect(SERVER)
|
|
||||||
|
|
||||||
with pytest.raises(IslandAPIRequestFailedError):
|
|
||||||
m.get(ISLAND_AGENT_STOP_URI, content=b"bad")
|
|
||||||
island_api_client.should_agent_stop(AGENT_ID)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"actual_error, expected_error",
|
"actual_error, expected_error",
|
||||||
[
|
[
|
||||||
|
@ -461,3 +418,62 @@ def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_
|
||||||
with pytest.raises(IslandAPIRequestFailedError):
|
with pytest.raises(IslandAPIRequestFailedError):
|
||||||
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad")
|
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad")
|
||||||
island_api_client.get_credentials_for_propagation()
|
island_api_client.get_credentials_for_propagation()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"actual_error, expected_error",
|
||||||
|
[
|
||||||
|
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||||
|
(TimeoutError, IslandAPITimeoutError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client__get_agent_signals(island_api_client, actual_error, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_GET_AGENT_SIGNALS, exc=actual_error)
|
||||||
|
island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code, expected_error",
|
||||||
|
[
|
||||||
|
(401, IslandAPIRequestError),
|
||||||
|
(501, IslandAPIRequestFailedError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client_get_agent_signals__status_code(
|
||||||
|
island_api_client, status_code, expected_error
|
||||||
|
):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_GET_AGENT_SIGNALS, status_code=status_code)
|
||||||
|
island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("timestamp", [TIMESTAMP, None])
|
||||||
|
def test_island_api_client_get_agent_signals(island_api_client, timestamp):
|
||||||
|
expected_agent_signals = AgentSignals(terminate=timestamp)
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp})
|
||||||
|
actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||||
|
|
||||||
|
assert actual_agent_signals == expected_agent_signals
|
||||||
|
|
||||||
|
|
||||||
|
def test_island_api_client_get_agent_signals__bad_json(island_api_client):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(IslandAPIError):
|
||||||
|
m.get(ISLAND_GET_AGENT_SIGNALS, json={"bogus": "vogus"})
|
||||||
|
island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
from typing import Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from infection_monkey.i_control_channel import IslandCommunicationError
|
from common import AgentSignals
|
||||||
|
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||||
from infection_monkey.island_api_client import (
|
from infection_monkey.island_api_client import (
|
||||||
IIslandAPIClient,
|
IIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
|
@ -33,16 +35,24 @@ def control_channel(island_api_client) -> ControlChannel:
|
||||||
return ControlChannel(SERVER, AGENT_ID, island_api_client)
|
return ControlChannel(SERVER, AGENT_ID, island_api_client)
|
||||||
|
|
||||||
|
|
||||||
def test_control_channel__should_agent_stop(control_channel, island_api_client):
|
@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
|
||||||
control_channel.should_agent_stop()
|
def test_control_channel__should_agent_stop(
|
||||||
assert island_api_client.should_agent_stop.called_once()
|
control_channel: IControlChannel,
|
||||||
|
island_api_client: IIslandAPIClient,
|
||||||
|
signal_time: Optional[int],
|
||||||
|
expected_should_stop: bool,
|
||||||
|
):
|
||||||
|
island_api_client.get_agent_signals = MagicMock(
|
||||||
|
return_value=AgentSignals(terminate=signal_time)
|
||||||
|
)
|
||||||
|
assert control_channel.should_agent_stop() is expected_should_stop
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||||
def test_control_channel__should_agent_stop_raises_error(
|
def test_control_channel__should_agent_stop_raises_error(
|
||||||
control_channel, island_api_client, api_error
|
control_channel, island_api_client, api_error
|
||||||
):
|
):
|
||||||
island_api_client.should_agent_stop.side_effect = api_error()
|
island_api_client.get_agent_signals.side_effect = api_error()
|
||||||
|
|
||||||
with pytest.raises(IslandCommunicationError):
|
with pytest.raises(IslandCommunicationError):
|
||||||
control_channel.should_agent_stop()
|
control_channel.should_agent_stop()
|
||||||
|
|
|
@ -3,7 +3,7 @@ import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError, ParentNotFoundError
|
from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError
|
||||||
from monkey_island.cc.models.monkey_ttl import MonkeyTtl
|
from monkey_island.cc.models.monkey_ttl import MonkeyTtl
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -162,35 +162,3 @@ class TestMonkey:
|
||||||
|
|
||||||
cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info()
|
cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info()
|
||||||
assert cache_info_after_query.hits == 2
|
assert cache_info_after_query.hits == 2
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_has_parent(self):
|
|
||||||
monkey_1 = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey_2 = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey_1.parent = [[monkey_2.guid]]
|
|
||||||
monkey_1.save()
|
|
||||||
assert monkey_1.has_parent()
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_has_no_parent(self):
|
|
||||||
monkey_1 = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey_1.parent = [[monkey_1.guid]]
|
|
||||||
monkey_1.save()
|
|
||||||
assert not monkey_1.has_parent()
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_get_parent(self):
|
|
||||||
monkey_1 = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey_2 = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey_1.parent = [[monkey_2.guid]]
|
|
||||||
monkey_1.save()
|
|
||||||
monkey_2.save()
|
|
||||||
assert monkey_1.get_parent().guid == monkey_2.guid
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_get_parent_no_parent(self):
|
|
||||||
monkey_1 = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey_1.parent = [[monkey_1.guid]]
|
|
||||||
monkey_1.save()
|
|
||||||
with pytest.raises(ParentNotFoundError):
|
|
||||||
monkey_1.get_parent()
|
|
||||||
|
|
|
@ -17,14 +17,29 @@ from monkey_island.cc.repository import (
|
||||||
)
|
)
|
||||||
|
|
||||||
VICTIM_ZERO_ID = uuid4()
|
VICTIM_ZERO_ID = uuid4()
|
||||||
|
VICTIM_TWO_ID = uuid4()
|
||||||
|
VICTIM_THREE_ID = uuid4()
|
||||||
|
|
||||||
|
PROGENITOR_AGENT = Agent(
|
||||||
|
id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)
|
||||||
|
)
|
||||||
|
|
||||||
|
DESCENDANT_AGENT = Agent(
|
||||||
|
id=VICTIM_THREE_ID,
|
||||||
|
machine_id=4,
|
||||||
|
start_time=datetime.fromtimestamp(1661856868),
|
||||||
|
parent_id=VICTIM_TWO_ID,
|
||||||
|
)
|
||||||
|
|
||||||
RUNNING_AGENTS = (
|
RUNNING_AGENTS = (
|
||||||
Agent(id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)),
|
PROGENITOR_AGENT,
|
||||||
Agent(
|
Agent(
|
||||||
id=uuid4(),
|
id=VICTIM_TWO_ID,
|
||||||
machine_id=2,
|
machine_id=2,
|
||||||
start_time=datetime.fromtimestamp(1661856818),
|
start_time=datetime.fromtimestamp(1661856818),
|
||||||
parent_id=VICTIM_ZERO_ID,
|
parent_id=VICTIM_ZERO_ID,
|
||||||
),
|
),
|
||||||
|
DESCENDANT_AGENT,
|
||||||
)
|
)
|
||||||
STOPPED_AGENTS = (
|
STOPPED_AGENTS = (
|
||||||
Agent(
|
Agent(
|
||||||
|
@ -172,6 +187,24 @@ def test_get_running_agents__retrieval_error(error_raising_agent_repository):
|
||||||
error_raising_agent_repository.get_running_agents()
|
error_raising_agent_repository.get_running_agents()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", [DESCENDANT_AGENT, PROGENITOR_AGENT])
|
||||||
|
def test_get_progenitor(agent_repository, agent):
|
||||||
|
actual_progenitor = agent_repository.get_progenitor(agent)
|
||||||
|
|
||||||
|
assert actual_progenitor == PROGENITOR_AGENT
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_progenitor__id_not_found(agent_repository):
|
||||||
|
dummy_agent = Agent(id=uuid4(), machine_id=10, start_time=datetime.now(), parent_id=uuid4())
|
||||||
|
with pytest.raises(UnknownRecordError):
|
||||||
|
agent_repository.get_progenitor(dummy_agent)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_progenitor__retrieval_error(error_raising_agent_repository):
|
||||||
|
with pytest.raises(RetrievalError):
|
||||||
|
error_raising_agent_repository.get_progenitor(AGENTS[1])
|
||||||
|
|
||||||
|
|
||||||
def test_reset(agent_repository):
|
def test_reset(agent_repository):
|
||||||
# Ensure the repository is not empty
|
# Ensure the repository is not empty
|
||||||
for agent in AGENTS:
|
for agent in AGENTS:
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tests.common import StubDIContainer
|
||||||
|
|
||||||
|
from common.agent_signals import AgentSignals as Signals
|
||||||
|
from monkey_island.cc.repository import RetrievalError, StorageError
|
||||||
|
from monkey_island.cc.services import AgentSignalsService
|
||||||
|
|
||||||
|
TIMESTAMP_1 = 123456789
|
||||||
|
TIMESTAMP_2 = 123546789
|
||||||
|
|
||||||
|
SIGNALS_1 = Signals(terminate=TIMESTAMP_1)
|
||||||
|
SIGNALS_2 = Signals(terminate=TIMESTAMP_2)
|
||||||
|
|
||||||
|
AGENT_ID_1 = UUID("c0dd10b3-e21a-4da9-9d96-a99c19ebd7c5")
|
||||||
|
AGENT_ID_2 = UUID("9b4279f6-6ec5-4953-821e-893ddc71a988")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent_signals_service():
|
||||||
|
return MagicMock(spec=AgentSignalsService)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flask_client(build_flask_client, mock_agent_signals_service):
|
||||||
|
container = StubDIContainer()
|
||||||
|
container.register_instance(AgentSignalsService, mock_agent_signals_service)
|
||||||
|
|
||||||
|
with build_flask_client(container) as flask_client:
|
||||||
|
yield flask_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url, signals",
|
||||||
|
[
|
||||||
|
(f"/api/agent-signals/{AGENT_ID_1}", SIGNALS_1),
|
||||||
|
(f"/api/agent-signals/{AGENT_ID_2}", SIGNALS_2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_agent_signals_get(flask_client, mock_agent_signals_service, url, signals):
|
||||||
|
mock_agent_signals_service.get_signals.return_value = signals
|
||||||
|
resp = flask_client.get(url, follow_redirects=True)
|
||||||
|
assert resp.status_code == HTTPStatus.OK
|
||||||
|
assert resp.json == signals.dict(simplify=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url, error",
|
||||||
|
[
|
||||||
|
(f"/api/agent-signals/{AGENT_ID_1}", RetrievalError),
|
||||||
|
(f"/api/agent-signals/{AGENT_ID_2}", StorageError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_agent_signals_get__internal_server_error(
|
||||||
|
flask_client, mock_agent_signals_service, url, error
|
||||||
|
):
|
||||||
|
mock_agent_signals_service.get_signals.side_effect = error
|
||||||
|
resp = flask_client.get(url, follow_redirects=True)
|
||||||
|
assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
|
|
@ -0,0 +1,51 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tests.common import StubDIContainer
|
||||||
|
|
||||||
|
from monkey_island.cc.event_queue import IIslandEventQueue
|
||||||
|
from monkey_island.cc.resources import TerminateAllAgents
|
||||||
|
|
||||||
|
TIMESTAMP = 123456789
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flask_client(build_flask_client):
|
||||||
|
container = StubDIContainer()
|
||||||
|
|
||||||
|
mock_island_event_queue = MagicMock(spec=IIslandEventQueue)
|
||||||
|
mock_island_event_queue.publish.side_effect = None
|
||||||
|
container.register_instance(IIslandEventQueue, mock_island_event_queue)
|
||||||
|
|
||||||
|
with build_flask_client(container) as flask_client:
|
||||||
|
yield flask_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_terminate_all_agents_post(flask_client):
|
||||||
|
resp = flask_client.post(
|
||||||
|
TerminateAllAgents.urls[0],
|
||||||
|
json={"terminate_time": TIMESTAMP},
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
assert resp.status_code == HTTPStatus.NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"bad_data",
|
||||||
|
[
|
||||||
|
"bad timestamp",
|
||||||
|
{},
|
||||||
|
{"wrong_key": TIMESTAMP},
|
||||||
|
TIMESTAMP,
|
||||||
|
{"terminate_time": 0},
|
||||||
|
{"terminate_time": -1},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_terminate_all_agents_post__invalid_timestamp(flask_client, bad_data):
|
||||||
|
resp = flask_client.post(
|
||||||
|
TerminateAllAgents.urls[0],
|
||||||
|
json=bad_data,
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
assert resp.status_code == HTTPStatus.BAD_REQUEST
|
|
@ -0,0 +1,144 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common.types import AgentID
|
||||||
|
from monkey_island.cc.models import Agent, IslandMode, Simulation
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError
|
||||||
|
from monkey_island.cc.services import AgentSignalsService
|
||||||
|
|
||||||
|
AGENT_1 = Agent(
|
||||||
|
id=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"),
|
||||||
|
machine_id=1,
|
||||||
|
start_time=100,
|
||||||
|
parent_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_2 = Agent(
|
||||||
|
id=UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10"),
|
||||||
|
machine_id=2,
|
||||||
|
start_time=200,
|
||||||
|
parent_id=AGENT_1.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_3 = Agent(
|
||||||
|
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"),
|
||||||
|
machine_id=3,
|
||||||
|
start_time=300,
|
||||||
|
parent_id=AGENT_2.id,
|
||||||
|
)
|
||||||
|
AGENTS = [AGENT_1, AGENT_2, AGENT_3]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_simulation_repository() -> IAgentRepository:
|
||||||
|
return MagicMock(spec=ISimulationRepository)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mock_agent_repository() -> IAgentRepository:
|
||||||
|
def get_agent_by_id(agent_id: AgentID) -> Agent:
|
||||||
|
for agent in AGENTS:
|
||||||
|
if agent.id == agent_id:
|
||||||
|
return agent
|
||||||
|
|
||||||
|
raise UnknownRecordError(str(agent_id))
|
||||||
|
|
||||||
|
agent_repository = MagicMock(spec=IAgentRepository)
|
||||||
|
agent_repository.get_progenitor = MagicMock(return_value=AGENT_1)
|
||||||
|
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
|
||||||
|
|
||||||
|
return agent_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_signals_service(mock_simulation_repository, mock_agent_repository) -> AgentSignalsService:
|
||||||
|
return AgentSignalsService(mock_simulation_repository, mock_agent_repository)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_terminate_is_none(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
assert signals.terminate is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_agent_started_before_terminate(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
TERMINATE_TIMESTAMP = 400
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
|
||||||
|
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_agent_started_after_terminate(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
TERMINATE_TIMESTAMP = 50
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
|
||||||
|
assert signals.terminate is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_progenitor_started_before_terminate(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
TERMINATE_TIMESTAMP = 150
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
|
||||||
|
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_terminate_agents_signal__stores_timestamp(
|
||||||
|
agent_signals_service: AgentSignalsService, mock_simulation_repository: ISimulationRepository
|
||||||
|
):
|
||||||
|
timestamp = 100
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(return_value=Simulation())
|
||||||
|
agent_signals_service.on_terminate_agents_signal(timestamp)
|
||||||
|
|
||||||
|
expected_value = Simulation(terminate_signal_time=timestamp)
|
||||||
|
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_terminate_agents_signal__updates_timestamp(
|
||||||
|
agent_signals_service: AgentSignalsService, mock_simulation_repository: ISimulationRepository
|
||||||
|
):
|
||||||
|
timestamp = 100
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=50)
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_signals_service.on_terminate_agents_signal(timestamp)
|
||||||
|
|
||||||
|
expected_value = Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=timestamp)
|
||||||
|
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)
|
|
@ -1,26 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from monkey_island.cc.models import Config, Monkey
|
from monkey_island.cc.models import Config, Monkey
|
||||||
from monkey_island.cc.models.agent_controls import AgentControls
|
|
||||||
from monkey_island.cc.services.infection_lifecycle import should_agent_die
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_should_agent_die_by_config(monkeypatch):
|
|
||||||
monkey = Monkey(guid=str(uuid.uuid4()))
|
|
||||||
monkey.config = Config()
|
|
||||||
monkey.should_stop = True
|
|
||||||
monkey.save()
|
|
||||||
assert should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"monkey_island.cc.services.infection_lifecycle._is_monkey_killed_manually", lambda _: False
|
|
||||||
)
|
|
||||||
monkey.should_stop = True
|
|
||||||
monkey.save()
|
|
||||||
assert not should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
def create_monkey(launch_time):
|
def create_monkey(launch_time):
|
||||||
|
@ -32,80 +12,9 @@ def create_monkey(launch_time):
|
||||||
return monkey
|
return monkey
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_should_agent_die_no_kill_event():
|
|
||||||
monkey = create_monkey(launch_time=3)
|
|
||||||
kill_event = AgentControls()
|
|
||||||
kill_event.save()
|
|
||||||
assert not should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
def create_kill_event(event_time):
|
|
||||||
kill_event = AgentControls(last_stop_all=event_time)
|
|
||||||
kill_event.save()
|
|
||||||
return kill_event
|
|
||||||
|
|
||||||
|
|
||||||
def create_parent(child_monkey, launch_time):
|
def create_parent(child_monkey, launch_time):
|
||||||
monkey_parent = Monkey(guid=str(uuid.uuid4()))
|
monkey_parent = Monkey(guid=str(uuid.uuid4()))
|
||||||
child_monkey.parent = [[monkey_parent.guid]]
|
child_monkey.parent = [[monkey_parent.guid]]
|
||||||
monkey_parent.launch_time = launch_time
|
monkey_parent.launch_time = launch_time
|
||||||
monkey_parent.save()
|
monkey_parent.save()
|
||||||
child_monkey.save()
|
child_monkey.save()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_was_agent_killed_manually(monkeypatch):
|
|
||||||
monkey = create_monkey(launch_time=2)
|
|
||||||
|
|
||||||
create_kill_event(event_time=3)
|
|
||||||
|
|
||||||
assert should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_agent_killed_on_wakeup(monkeypatch):
|
|
||||||
monkey = create_monkey(launch_time=2)
|
|
||||||
|
|
||||||
create_kill_event(event_time=2)
|
|
||||||
|
|
||||||
assert should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_manual_kill_dont_affect_new_monkeys(monkeypatch):
|
|
||||||
monkey = create_monkey(launch_time=3)
|
|
||||||
|
|
||||||
create_kill_event(event_time=2)
|
|
||||||
|
|
||||||
assert not should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_parent_manually_killed(monkeypatch):
|
|
||||||
monkey = create_monkey(launch_time=3)
|
|
||||||
create_parent(child_monkey=monkey, launch_time=1)
|
|
||||||
|
|
||||||
create_kill_event(event_time=2)
|
|
||||||
|
|
||||||
assert should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_parent_manually_killed_on_wakeup(monkeypatch):
|
|
||||||
monkey = create_monkey(launch_time=3)
|
|
||||||
create_parent(child_monkey=monkey, launch_time=2)
|
|
||||||
|
|
||||||
create_kill_event(event_time=2)
|
|
||||||
|
|
||||||
assert should_agent_die(monkey.guid)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("uses_database")
|
|
||||||
def test_manual_kill_dont_affect_new_monkeys_with_parent(monkeypatch):
|
|
||||||
monkey = create_monkey(launch_time=3)
|
|
||||||
create_parent(child_monkey=monkey, launch_time=2)
|
|
||||||
|
|
||||||
create_kill_event(event_time=1)
|
|
||||||
|
|
||||||
assert not should_agent_die(monkey.guid)
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ from monkey_island.cc.repository.i_simulation_repository import ISimulationRepos
|
||||||
from monkey_island.cc.repository.ICredentials import ICredentialsRepository
|
from monkey_island.cc.repository.ICredentials import ICredentialsRepository
|
||||||
from monkey_island.cc.repository.zero_trust.IEventRepository import IEventRepository
|
from monkey_island.cc.repository.zero_trust.IEventRepository import IEventRepository
|
||||||
from monkey_island.cc.repository.zero_trust.IFindingRepository import IFindingRepository
|
from monkey_island.cc.repository.zero_trust.IFindingRepository import IFindingRepository
|
||||||
|
from monkey_island.cc.services import AgentSignalsService
|
||||||
|
|
||||||
fake_monkey_dir_path # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37)
|
fake_monkey_dir_path # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37)
|
||||||
set_os_linux # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37)
|
set_os_linux # unused variable (monkey/tests/infection_monkey/post_breach/actions/test_users_custom_pba.py:37)
|
||||||
|
|
Loading…
Reference in New Issue