diff --git a/monkey/monkey_island/cc/models/node.py b/monkey/monkey_island/cc/models/node.py index 4b3ab5608..715e52bb3 100644 --- a/monkey/monkey_island/cc/models/node.py +++ b/monkey/monkey_island/cc/models/node.py @@ -1,4 +1,4 @@ -from typing import Mapping, Tuple +from typing import FrozenSet, Mapping from pydantic import Field from typing_extensions import TypeAlias @@ -7,7 +7,7 @@ from common.base_models import MutableInfectionMonkeyBaseModel from . import CommunicationType, MachineID -NodeConnections: TypeAlias = Mapping[MachineID, Tuple[CommunicationType, ...]] +NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]] class Node(MutableInfectionMonkeyBaseModel): diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index 6bda93bbf..e8e1b5b25 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -24,3 +24,4 @@ from .json_file_user_repository import JSONFileUserRepository from .mongo_credentials_repository import MongoCredentialsRepository from .mongo_machine_repository import MongoMachineRepository from .mongo_agent_repository import MongoAgentRepository +from .mongo_node_repository import MongoNodeRepository diff --git a/monkey/monkey_island/cc/repository/i_node_repository.py b/monkey/monkey_island/cc/repository/i_node_repository.py index b8ea0c49d..d848d58fe 100644 --- a/monkey/monkey_island/cc/repository/i_node_repository.py +++ b/monkey/monkey_island/cc/repository/i_node_repository.py @@ -33,3 +33,12 @@ class INodeRepository(ABC): :return: All known Nodes :raises RetrievalError: If an error occurred while attempting to retrieve the nodes """ + + @abstractmethod + def reset(self): + """ + Removes all data from the repository + + :raises RemovalError: If an error occurred while attempting to remove all `Nodes` from the + repository + """ diff --git a/monkey/monkey_island/cc/repository/mongo_agent_repository.py b/monkey/monkey_island/cc/repository/mongo_agent_repository.py index 3b1b60f0a..dfad2bbf7 100644 --- a/monkey/monkey_island/cc/repository/mongo_agent_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_agent_repository.py @@ -1,4 +1,4 @@ -from typing import Any, MutableMapping, Sequence +from typing import Sequence from pymongo import MongoClient @@ -40,27 +40,24 @@ class MongoAgentRepository(IAgentRepository): def get_agent_by_id(self, agent_id: AgentID) -> Agent: try: - agent_dict = self._agents_collection.find_one({"id": str(agent_id)}) + agent_dict = self._agents_collection.find_one( + {"id": str(agent_id)}, {MONGO_OBJECT_ID_KEY: False} + ) except Exception as err: raise RetrievalError(f'Error retrieving agent with "id == {agent_id}": {err}') if agent_dict is None: raise UnknownRecordError(f'Unknown ID "{agent_id}"') - return MongoAgentRepository._mongo_record_to_agent(agent_dict) + return Agent(**agent_dict) def get_running_agents(self) -> Sequence[Agent]: try: - cursor = self._agents_collection.find({"stop_time": None}) - return list(map(MongoAgentRepository._mongo_record_to_agent, cursor)) + cursor = self._agents_collection.find({"stop_time": None}, {MONGO_OBJECT_ID_KEY: False}) + return list(map(lambda a: Agent(**a), cursor)) except Exception as err: raise RetrievalError(f"Error retrieving running agents: {err}") - @staticmethod - def _mongo_record_to_agent(mongo_record: MutableMapping[str, Any]) -> Agent: - del mongo_record[MONGO_OBJECT_ID_KEY] - return Agent(**mongo_record) - def reset(self): try: self._agents_collection.drop() diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index 237f052e4..3fdc306a8 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -55,9 +55,8 @@ class MongoCredentialsRepository(ICredentialsRepository): def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]: try: collection_result = [] - list_collection_result = list(collection.find({})) + list_collection_result = list(collection.find({}, {MONGO_OBJECT_ID_KEY: False})) for encrypted_credentials in list_collection_result: - del encrypted_credentials[MONGO_OBJECT_ID_KEY] plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials) collection_result.append(Credentials(**plaintext_credentials)) diff --git a/monkey/monkey_island/cc/repository/mongo_machine_repository.py b/monkey/monkey_island/cc/repository/mongo_machine_repository.py index 4d1a36470..fab038694 100644 --- a/monkey/monkey_island/cc/repository/mongo_machine_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_machine_repository.py @@ -1,6 +1,6 @@ from ipaddress import IPv4Address from threading import Lock -from typing import Any, MutableMapping, Sequence +from typing import Any, Sequence from pymongo import MongoClient @@ -58,36 +58,33 @@ class MongoMachineRepository(IMachineRepository): def _find_one(self, key: str, search_value: Any) -> Machine: try: - machine_dict = self._machines_collection.find_one({key: search_value}) + machine_dict = self._machines_collection.find_one( + {key: search_value}, {MONGO_OBJECT_ID_KEY: False} + ) except Exception as err: raise RetrievalError(f'Error retrieving machine with "{key} == {search_value}": {err}') if machine_dict is None: raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"') - return MongoMachineRepository._mongo_record_to_machine(machine_dict) + return Machine(**machine_dict) def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$" query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}} try: - cursor = self._machines_collection.find(query) + cursor = self._machines_collection.find(query, {MONGO_OBJECT_ID_KEY: False}) except Exception as err: raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}') - machines = list(map(MongoMachineRepository._mongo_record_to_machine, cursor)) + machines = list(map(lambda m: Machine(**m), cursor)) if len(machines) == 0: raise UnknownRecordError(f'No machines found with IP "{ip}"') return machines - @staticmethod - def _mongo_record_to_machine(mongo_record: MutableMapping[str, Any]) -> Machine: - del mongo_record[MONGO_OBJECT_ID_KEY] - return Machine(**mongo_record) - def reset(self): try: self._machines_collection.drop() diff --git a/monkey/monkey_island/cc/repository/mongo_node_repository.py b/monkey/monkey_island/cc/repository/mongo_node_repository.py new file mode 100644 index 000000000..5b1b7f71f --- /dev/null +++ b/monkey/monkey_island/cc/repository/mongo_node_repository.py @@ -0,0 +1,84 @@ +from copy import deepcopy +from typing import Sequence + +from pymongo import MongoClient + +from monkey_island.cc.models import CommunicationType, MachineID, Node + +from . import INodeRepository, RemovalError, RetrievalError, StorageError +from .consts import MONGO_OBJECT_ID_KEY + +UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node" +SRC_FIELD_NAME = "machine_id" + + +class MongoNodeRepository(INodeRepository): + def __init__(self, mongo_client: MongoClient): + self._nodes_collection = mongo_client.monkey_island.nodes + + def upsert_communication( + self, src: MachineID, dst: MachineID, communication_type: CommunicationType + ): + try: + node_dict = self._nodes_collection.find_one( + {SRC_FIELD_NAME: src}, {MONGO_OBJECT_ID_KEY: False} + ) + except Exception as err: + raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}") + + if node_dict is None: + updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))}) + else: + node = Node(**node_dict) + updated_node = MongoNodeRepository._add_connection_to_node( + node, dst, communication_type + ) + + self._upsert_node(updated_node) + + @staticmethod + def _add_connection_to_node( + node: Node, dst: MachineID, communication_type: CommunicationType + ) -> Node: + connections = dict(deepcopy(node.connections)) + communications = set(connections.get(dst, set())) + communications.add(communication_type) + connections[dst] = frozenset(communications) + + new_node = node.copy() + new_node.connections = connections + + return new_node + + def _upsert_node(self, node: Node): + try: + result = self._nodes_collection.replace_one( + {SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True + ) + except Exception as err: + raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}") + + if result.matched_count != 0 and result.modified_count != 1: + raise StorageError( + f'Error updating node with source ID "{node.machine_id}": Expected to update 1 ' + f"node, but {result.modified_count} were updated" + ) + + if result.matched_count == 0 and result.upserted_id is None: + raise StorageError( + f'Error inserting node with source ID "{node.machine_id}": Expected to insert 1 ' + f"node, but no nodes were inserted" + ) + + def get_nodes(self) -> Sequence[Node]: + try: + cursor = self._nodes_collection.find({}, {MONGO_OBJECT_ID_KEY: False}) + return list(map(lambda n: Node(**n), cursor)) + except Exception as err: + raise RetrievalError(f"Error retrieving nodes from the repository: {err}") + + def reset(self): + try: + self._nodes_collection.drop() + except Exception as err: + raise RemovalError(f"Error resetting the repository: {err}") diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py index e70df6975..74a83860c 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py @@ -8,8 +8,8 @@ from monkey_island.cc.models import CommunicationType, Node def test_constructor(): machine_id = 1 connections = { - 6: (CommunicationType.SCANNED,), - 7: (CommunicationType.SCANNED, CommunicationType.EXPLOITED), + 6: frozenset((CommunicationType.SCANNED,)), + 7: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), } n = Node( machine_id=1, @@ -24,13 +24,25 @@ def test_serialization(): node_dict = { "machine_id": 1, "connections": { - "6": ["cc", "scanned"], - "7": ["exploited", "cc"], + "6": [CommunicationType.CC.value, CommunicationType.SCANNED.value], + "7": [CommunicationType.EXPLOITED.value, CommunicationType.CC.value], }, } + # "6": frozenset((CommunicationType.CC, CommunicationType.SCANNED)), + # "7": frozenset((CommunicationType.EXPLOITED, CommunicationType.CC)), n = Node(**node_dict) - assert n.dict(simplify=True) == node_dict + serialized_node = n.dict(simplify=True) + + # NOTE: Comparing these nodes is difficult because sets are not ordered + assert len(serialized_node) == len(node_dict) + for key in serialized_node.keys(): + assert key in node_dict + + assert len(serialized_node["connections"]) == len(node_dict["connections"]) + + for key, value in serialized_node["connections"].items(): + assert set(value) == set(node_dict["connections"][key]) def test_machine_id_immutable(): diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py new file mode 100644 index 000000000..b056cb575 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py @@ -0,0 +1,214 @@ +from unittest.mock import MagicMock + +import mongomock +import pytest + +from monkey_island.cc.models import CommunicationType, Node +from monkey_island.cc.repository import ( + INodeRepository, + MongoNodeRepository, + RemovalError, + RetrievalError, + StorageError, +) + +NODES = ( + Node( + machine_id=1, + connections={ + 2: frozenset((CommunicationType.SCANNED,)), + 3: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), + }, + ), + Node( + machine_id=2, + connections={1: frozenset((CommunicationType.CC,))}, + ), + Node( + machine_id=3, + connections={ + 1: frozenset((CommunicationType.CC,)), + 4: frozenset((CommunicationType.SCANNED,)), + 5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), + }, + ), + Node( + machine_id=4, + connections={}, + ), + Node( + machine_id=5, + connections={ + 2: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), + 3: frozenset((CommunicationType.CC,)), + }, + ), +) + + +@pytest.fixture +def empty_node_repository() -> INodeRepository: + return MongoNodeRepository(mongomock.MongoClient()) + + +@pytest.fixture +def mongo_client() -> mongomock.MongoClient: + client = mongomock.MongoClient() + client.monkey_island.nodes.insert_many((n.dict(simplify=True) for n in NODES)) + return client + + +@pytest.fixture +def node_repository(mongo_client) -> INodeRepository: + return MongoNodeRepository(mongo_client) + + +@pytest.fixture +def error_raising_mock_mongo_client() -> mongomock.MongoClient: + mongo_client = MagicMock(spec=mongomock.MongoClient) + mongo_client.monkey_island = MagicMock(spec=mongomock.Database) + mongo_client.monkey_island.nodes = MagicMock(spec=mongomock.Collection) + + mongo_client.monkey_island.nodes.find = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.nodes.find_one = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.nodes.replace_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.nodes.drop = MagicMock(side_effect=Exception("some exception")) + + return mongo_client + + +@pytest.fixture +def error_raising_node_repository(error_raising_mock_mongo_client) -> INodeRepository: + return MongoNodeRepository(error_raising_mock_mongo_client) + + +def test_upsert_communication__empty_repository(empty_node_repository): + src_machine_id = 1 + dst_machine_id = 2 + expected_node = Node( + machine_id=src_machine_id, + connections={dst_machine_id: frozenset((CommunicationType.SCANNED,))}, + ) + + empty_node_repository.upsert_communication( + src_machine_id, dst_machine_id, CommunicationType.SCANNED + ) + nodes = empty_node_repository.get_nodes() + + assert len(nodes) == 1 + assert nodes[0] == expected_node + + +def test_upsert_communication__new_node(node_repository): + src_machine_id = NODES[-1].machine_id + 100 + dst_machine_id = 1 + expected_nodes = NODES + ( + Node( + machine_id=src_machine_id, + connections={dst_machine_id: frozenset((CommunicationType.CC,))}, + ), + ) + node_repository.upsert_communication(src_machine_id, dst_machine_id, CommunicationType.CC) + nodes = node_repository.get_nodes() + + assert len(nodes) == len(expected_nodes) + for en in expected_nodes: + assert en in nodes + + +def test_upsert_communication__update_existing_connection(node_repository): + src_machine_id = 1 + dst_machine_id = 2 + expected_node = NODES[0].copy(deep=True) + expected_node.connections[2] = frozenset( + (*expected_node.connections[2], CommunicationType.EXPLOITED) + ) + node_repository.upsert_communication( + src_machine_id, dst_machine_id, CommunicationType.EXPLOITED + ) + nodes = node_repository.get_nodes() + + for node in nodes: + if node.machine_id == src_machine_id: + assert node == expected_node + break + + +def test_upsert_communication__update_existing_node_add_connection(node_repository): + src_machine_id = 2 + dst_machine_id = 5 + expected_node = NODES[1].copy(deep=True) + expected_node.connections[5] = frozenset((CommunicationType.SCANNED,)) + node_repository.upsert_communication(src_machine_id, dst_machine_id, CommunicationType.SCANNED) + nodes = node_repository.get_nodes() + + for node in nodes: + if node.machine_id == src_machine_id: + assert node == expected_node + break + + +def test_upsert_communication__find_one_fails(error_raising_node_repository): + with pytest.raises(StorageError): + error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED) + + +def test_upsert_communication__replace_one_fails( + error_raising_mock_mongo_client, error_raising_node_repository +): + error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None + with pytest.raises(StorageError): + error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED) + + +def test_upsert_communication__replace_one_matched_without_modify( + error_raising_mock_mongo_client, error_raising_node_repository +): + mock_result = MagicMock() + mock_result.matched_count = 1 + mock_result.modified_count = 0 + error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None + error_raising_mock_mongo_client.monkey_island.nodes.replace_one = lambda *_, **__: mock_result + + with pytest.raises(StorageError): + error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED) + + +def test_upsert_communication__replace_one_insert_fails( + error_raising_mock_mongo_client, error_raising_node_repository +): + mock_result = MagicMock() + mock_result.matched_count = 0 + mock_result.upserted_id = None + error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None + error_raising_mock_mongo_client.monkey_island.nodes.replace_one = lambda *_, **__: mock_result + + with pytest.raises(StorageError): + error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED) + + +def test_get_nodes(node_repository): + nodes = node_repository.get_nodes() + assert len(nodes) == len(nodes) + for n in nodes: + assert n in NODES + + +def test_get_nodes__find_fails(error_raising_node_repository): + with pytest.raises(RetrievalError): + error_raising_node_repository.get_nodes() + + +def test_reset(node_repository): + assert len(node_repository.get_nodes()) > 0 + + node_repository.reset() + + assert len(node_repository.get_nodes()) == 0 + + +def test_reset__removal_error(error_raising_node_repository): + with pytest.raises(RemovalError): + error_raising_node_repository.reset() diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 027a9e27e..34f81abb1 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -12,6 +12,7 @@ from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFacto from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue from monkey_island.cc.models import Report from monkey_island.cc.models.networkmap import Arc, NetworkMap +from monkey_island.cc.repository import MongoAgentRepository, MongoMachineRepository from monkey_island.cc.repository.attack.IMitigationsRepository import IMitigationsRepository from monkey_island.cc.repository.i_agent_repository import IAgentRepository from monkey_island.cc.repository.i_attack_repository import IAttackRepository @@ -275,6 +276,8 @@ ICredentialsRepository.save_stolen_credentials ICredentialsRepository.save_configured_credentials IEventRepository.get_events IFindingRepository.get_findings +MongoAgentRepository +MongoMachineRepository key_list simulation netmap