From 0adf9d84677de77cb2c85915a3d3f1cb91f2c91b Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 30 Aug 2022 05:03:47 -0400 Subject: [PATCH] Island: Add MongoMachineRepository --- .../monkey_island/cc/repository/__init__.py | 1 + .../cc/repository/mongo_machine_repository.py | 99 ++++++++ .../test_mongo_machine_repository.py | 238 ++++++++++++++++++ 3 files changed, 338 insertions(+) create mode 100644 monkey/monkey_island/cc/repository/mongo_machine_repository.py create mode 100644 monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index d9987e677..caef77b9b 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -20,3 +20,4 @@ from .file_agent_configuration_repository import FileAgentConfigurationRepositor from .file_simulation_repository import FileSimulationRepository from .json_file_user_repository import JSONFileUserRepository from .mongo_credentials_repository import MongoCredentialsRepository +from .mongo_machine_repository import MongoMachineRepository diff --git a/monkey/monkey_island/cc/repository/mongo_machine_repository.py b/monkey/monkey_island/cc/repository/mongo_machine_repository.py new file mode 100644 index 000000000..ba638e53a --- /dev/null +++ b/monkey/monkey_island/cc/repository/mongo_machine_repository.py @@ -0,0 +1,99 @@ +from ipaddress import IPv4Address +from threading import Lock +from typing import Any, MutableMapping, Sequence + +from pymongo import MongoClient + +from common.types import HardwareID +from monkey_island.cc.models import Machine, MachineID + +from . import IMachineRepository, RetrievalError, StorageError, UnknownRecordError +from .consts import MONGO_OBJECT_ID_KEY + + +class MongoMachineRepository(IMachineRepository): + """A repository used to store and retrieve Machines in MongoDB""" + + def __init__(self, mongo_client: MongoClient): + self._machines_collection = mongo_client.monkey_island.machines + self._id_lock = Lock() + self._next_id = self._get_biggest_id() + + def _get_biggest_id(self) -> MachineID: + try: + return self._machines_collection.find().sort("id", -1).limit(1)[0]["id"] + except IndexError: + return 0 + + def create_machine(self) -> Machine: + try: + next_id = self._get_next_id() + new_machine = Machine(id=next_id) + self._machines_collection.insert_one(new_machine.dict(simplify=True)) + + return new_machine + except Exception as err: + raise StorageError(f"Error creating a new machine: {err}") + + def _get_next_id(self) -> MachineID: + with self._id_lock: + self._next_id += 1 + return self._next_id + + def update_machine(self, machine: Machine): + try: + result = self._machines_collection.replace_one( + {"id": machine.id}, machine.dict(simplify=True) + ) + except Exception as err: + raise StorageError(f'Error updating machine with ID "{machine.id}": {err}') + + if result.matched_count == 0: + raise UnknownRecordError(f"Unknown machine: id == {machine.id}") + + if result.modified_count != 1: + raise StorageError( + f'Error updating machine with ID "{machine.id}": Expected to update 1 machines, ' + f"but updated {result.modified_count}" + ) + + def get_machine_by_id(self, id_: MachineID) -> Machine: + return self._find_one("id", "machine ID", id_) + + def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine: + return self._find_one("hardware_id", "hardware ID", hardware_id) + + def _find_one(self, key: str, key_display_name: str, search_value: Any) -> Machine: + try: + machine_dict = self._machines_collection.find_one({key: search_value}) + except Exception as err: + raise RetrievalError(f'Error retrieving machine with "{key} == {search_value}": {err}') + + if machine_dict is None: + raise UnknownRecordError(f'Unknown {key_display_name} "{search_value}"') + + return MongoMachineRepository._mongo_record_to_machine(machine_dict) + + def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: + ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$" + query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}} + + try: + cursor = self._machines_collection.find(query) + except Exception as err: + raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}') + + machines = list(map(MongoMachineRepository._mongo_record_to_machine, cursor)) + + if len(machines) == 0: + raise UnknownRecordError(f'No machines found with IP "{ip}"') + + return machines + + @staticmethod + def _mongo_record_to_machine(mongo_record: MutableMapping[str, Any]) -> Machine: + del mongo_record[MONGO_OBJECT_ID_KEY] + return Machine(**mongo_record) + + def reset(self): + self._machines_collection.drop() diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py new file mode 100644 index 000000000..9ff00c64b --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py @@ -0,0 +1,238 @@ +from ipaddress import IPv4Interface +from itertools import chain, repeat +from unittest.mock import MagicMock + +import mongomock +import pytest + +from common import OperatingSystem +from monkey_island.cc.models import Machine +from monkey_island.cc.repository import ( + IMachineRepository, + MongoMachineRepository, + RetrievalError, + StorageError, + UnknownRecordError, +) + +MACHINES = ( + Machine( + id=1, + hardware_id=12345, + network_interfaces=[IPv4Interface("192.168.1.10/24")], + operating_system=OperatingSystem.LINUX, + operating_system_version="Ubuntu 22.04", + hostname="wopr", + ), + Machine( + id=2, + hardware_id=67890, + network_interfaces=[IPv4Interface("192.168.1.11/24"), IPv4Interface("192.168.1.12/24")], + operating_system=OperatingSystem.WINDOWS, + operating_system_version="eXtra Problems", + hostname="hal", + ), + Machine( + id=3, + hardware_id=112345, + network_interfaces=[IPv4Interface("192.168.1.13/24"), IPv4Interface("192.168.1.14/24")], + operating_system=OperatingSystem.WINDOWS, + operating_system_version="Vista", + hostname="smith", + ), + Machine( + id=4, + hardware_id=167890, + network_interfaces=[IPv4Interface("192.168.1.14/24")], + operating_system=OperatingSystem.LINUX, + operating_system_version="CentOS Linux 8", + hostname="skynet", + ), +) + + +@pytest.fixture +def mongo_client() -> mongomock.MongoClient: + client = mongomock.MongoClient() + client.monkey_island.machines.insert_many((m.dict(simplify=True) for m in MACHINES)) + return client + + +@pytest.fixture +def error_raising_mock_mongo_client() -> mongomock.MongoClient: + mongo_client = MagicMock(spec=mongomock.MongoClient) + mongo_client.monkey_island = MagicMock(spec=mongomock.Database) + mongo_client.monkey_island.machines = MagicMock(spec=mongomock.Collection) + + # The first call to find() must succeed + mongo_client.monkey_island.machines.find = MagicMock( + side_effect=chain([MagicMock()], repeat(Exception("some exception"))) + ) + mongo_client.monkey_island.machines.find_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.machines.insert_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.machines.replace_one = MagicMock( + side_effect=Exception("some exception") + ) + + return mongo_client + + +@pytest.fixture +def error_raising_machine_repository(error_raising_mock_mongo_client) -> IMachineRepository: + return MongoMachineRepository(error_raising_mock_mongo_client) + + +@pytest.fixture +def machine_repository(mongo_client) -> IMachineRepository: + return MongoMachineRepository(mongo_client) + + +def test_create_machine__unique_id(machine_repository): + new_machine = machine_repository.create_machine() + + for m in MACHINES: + assert m.id != new_machine.id + + +def test_create_machine__multiple_unique_ids(machine_repository): + new_machine_1 = machine_repository.create_machine() + new_machine_2 = machine_repository.create_machine() + + assert new_machine_1.id != new_machine_2.id + + +def test_create_machine__new_id_for_empty_repo(machine_repository): + empty_machine_repository = MongoMachineRepository(mongomock.MongoClient()) + new_machine_1 = empty_machine_repository.create_machine() + new_machine_2 = empty_machine_repository.create_machine() + + assert new_machine_1.id != new_machine_2.id + + +def test_create_machine__storage_error(error_raising_machine_repository): + with pytest.raises(StorageError): + error_raising_machine_repository.create_machine() + + +def test_update_machine(machine_repository): + machine = machine_repository.get_machine_by_id(1) + + machine.operating_system = OperatingSystem.WINDOWS + machine.hostname = "viki" + machine.network_interfaces = [IPv4Interface("10.0.0.1/16")] + + machine_repository.update_machine(machine) + + assert machine_repository.get_machine_by_id(1) == machine + + +def test_update_machine__not_found(machine_repository): + machine = Machine(id=99) + + with pytest.raises(UnknownRecordError): + machine_repository.update_machine(machine) + + +def test_update_machine__storage_error_exception(error_raising_machine_repository): + machine = MACHINES[0] + + with pytest.raises(StorageError): + error_raising_machine_repository.update_machine(machine) + + +def test_update_machine__storage_error_update_failed(error_raising_mock_mongo_client): + mock_result = MagicMock() + mock_result.matched_count = 1 + mock_result.modified_count = 0 + + error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock( + return_value=mock_result + ) + machine_repository = MongoMachineRepository(error_raising_mock_mongo_client) + + machine = MACHINES[0] + with pytest.raises(StorageError): + machine_repository.update_machine(machine) + + +def test_get_machine_by_id(machine_repository): + for i, expected_machine in enumerate(MACHINES, start=1): + assert machine_repository.get_machine_by_id(i) == expected_machine + + +def test_get_machine_by_id__not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.get_machine_by_id(9999) + + +def test_get_machine_by_id__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machine_by_id(1) + + +def test_get_machine_by_hardware_id(machine_repository): + for hardware_id, expected_machine in ((machine.hardware_id, machine) for machine in MACHINES): + assert machine_repository.get_machine_by_hardware_id(hardware_id) == expected_machine + + +def test_get_machine_by_hardware_id__not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.get_machine_by_hardware_id(9999888887777) + + +def test_get_machine_by_hardware_id__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machine_by_hardware_id(1) + + +def test_get_machine_by_ip(machine_repository): + expected_machine = MACHINES[0] + expected_machine_ip = expected_machine.network_interfaces[0].ip + + retrieved_machines = machine_repository.get_machines_by_ip(expected_machine_ip) + + assert len(retrieved_machines) == 1 + assert retrieved_machines[0] == expected_machine + + +def test_get_machine_by_ip__multiple_results(machine_repository): + search_ip = MACHINES[3].network_interfaces[0].ip + + retrieved_machines = machine_repository.get_machines_by_ip(search_ip) + + assert len(retrieved_machines) == 2 + assert MACHINES[2] in retrieved_machines + assert MACHINES[3] in retrieved_machines + + +def test_get_machine_by_ip__not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.get_machines_by_ip("1.1.1.1") + + +def test_get_machine_by_ip__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machines_by_ip("1.1.1.1") + + +def test_reset(machine_repository): + # Ensure the repository is not empty + preexisting_machine = machine_repository.get_machine_by_id(MACHINES[0].id) + assert isinstance(preexisting_machine, Machine) + + machine_repository.reset() + + with pytest.raises(UnknownRecordError): + machine_repository.get_machine_by_id(MACHINES[0].id) + + +def test_usable_after_reset(machine_repository): + machine_repository.reset() + + new_machine = machine_repository.create_machine() + + assert new_machine == machine_repository.get_machine_by_id(new_machine.id)