diff --git a/monkey/monkey_island/cc/models/machine.py b/monkey/monkey_island/cc/models/machine.py index 359786d05..31299d169 100644 --- a/monkey/monkey_island/cc/models/machine.py +++ b/monkey/monkey_island/cc/models/machine.py @@ -13,12 +13,25 @@ MachineID: TypeAlias = PositiveInt class Machine(MutableInfectionMonkeyBaseModel): + """Represents machines, VMs, or other network nodes discovered by Infection Monkey""" + id: MachineID = Field(..., allow_mutation=False) + """Uniquely identifies the machine within the island""" + hardware_id: Optional[HardwareID] - network_interfaces: Sequence[IPv4Interface] - operating_system: OperatingSystem - operating_system_version: str - hostname: str + """An identifier generated by the agent that uniquely identifies a machine""" + + network_interfaces: Sequence[IPv4Interface] = tuple() + """The machine's networking interfaces""" + + operating_system: Optional[OperatingSystem] + """The operating system the machine is running""" + + operating_system_version: str = "" + """The specific version of the operating system the machine is running""" + + hostname: str = "" + """The hostname of the machine""" _make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)( make_immutable_sequence diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index 075317b43..caef77b9b 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -1,4 +1,4 @@ -from .errors import RemovalError, RetrievalError, StorageError +from .errors import RemovalError, RetrievalError, StorageError, UnknownRecordError from .i_file_repository import FileNotFoundError, IFileRepository @@ -7,6 +7,7 @@ from .i_agent_configuration_repository import IAgentConfigurationRepository from .i_simulation_repository import ISimulationRepository from .i_credentials_repository import ICredentialsRepository from .i_user_repository import 
IUserRepository +from .i_machine_repository import IMachineRepository from .local_storage_file_repository import LocalStorageFileRepository @@ -19,3 +20,4 @@ from .file_agent_configuration_repository import FileAgentConfigurationRepositor from .file_simulation_repository import FileSimulationRepository from .json_file_user_repository import JSONFileUserRepository from .mongo_credentials_repository import MongoCredentialsRepository +from .mongo_machine_repository import MongoMachineRepository diff --git a/monkey/monkey_island/cc/repository/consts.py b/monkey/monkey_island/cc/repository/consts.py new file mode 100644 index 000000000..161768005 --- /dev/null +++ b/monkey/monkey_island/cc/repository/consts.py @@ -0,0 +1 @@ +MONGO_OBJECT_ID_KEY = "_id" diff --git a/monkey/monkey_island/cc/repository/errors.py b/monkey/monkey_island/cc/repository/errors.py index aeb9fa23c..a5c26fe15 100644 --- a/monkey/monkey_island/cc/repository/errors.py +++ b/monkey/monkey_island/cc/repository/errors.py @@ -3,20 +3,20 @@ class RemovalError(RuntimeError): Raised when a repository encounters an error while attempting to remove data. """ - pass - class RetrievalError(RuntimeError): """ Raised when a repository encounters an error while attempting to retrieve data. """ - pass - class StorageError(RuntimeError): """ Raised when a repository encounters an error while attempting to store data. """ - pass + +class UnknownRecordError(RuntimeError): + """ + Raised when the repository does not contain any data matching the request. 
+ """ diff --git a/monkey/monkey_island/cc/repository/i_machine_repository.py b/monkey/monkey_island/cc/repository/i_machine_repository.py index 604d58068..d60346eb2 100644 --- a/monkey/monkey_island/cc/repository/i_machine_repository.py +++ b/monkey/monkey_island/cc/repository/i_machine_repository.py @@ -1,20 +1,75 @@ -from abc import ABC -from typing import Optional, Sequence +from abc import ABC, abstractmethod +from ipaddress import IPv4Address +from typing import Sequence + +from common.types import HardwareID +from monkey_island.cc.models import Machine, MachineID class IMachineRepository(ABC): - # TODO define Machine object(ORM model) - def save_machine(self, machine: Machine): # noqa: F821 - pass + """A repository used to store and retrieve Machines""" - # TODO define Machine object(ORM model) - # TODO define or re-use machine state. - # TODO investigate where should the state be stored in edge or both edge and machine - def get_machines( - self, - id: Optional[str] = None, - ips: Optional[Sequence[str]] = None, - state: Optional[MachineState] = None, # noqa: F821 - is_island: Optional[bool] = None, # noqa: F841 - ) -> Sequence[Machine]: # noqa: F821 - pass + @abstractmethod + def get_new_id(self) -> MachineID: + """ + Generates a new, unique `MachineID` + + :return: A new, unique `MachineID` + """ + + @abstractmethod + def upsert_machine(self, machine: Machine): + """ + Upsert (insert or update) a `Machine` + + Insert the `Machine` if no `Machine` with a matching ID exists in the repository. If the + `Machine` already exists, update it. 
+ + :param machine: The `Machine` to be inserted or updated + :raises StorageError: If an error occurred while attempting to store the `Machine` + """ + + @abstractmethod + def get_machine_by_id(self, machine_id: MachineID) -> Machine: + """ + Get a `Machine` by ID + + :param machine_id: The ID of the `Machine` to be retrieved + :return: A `Machine` with a matching `id` + :raises UnknownRecordError: If a `Machine` with the specified `id` does not exist in the + repository + :raises RetrievalError: If an error occurred while attempting to retrieve the `Machine` + """ + + @abstractmethod + def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine: + """ + Get a `Machine` by hardware ID + + :param hardware_id: The hardware ID of the `Machine` to be retrieved + :return: A `Machine` with a matching `hardware_id` + :raises UnknownRecordError: If a `Machine` with the specified `hardware_id` does not exist + in the repository + :raises RetrievalError: If an error occurred while attempting to retrieve the `Machine` + """ + + @abstractmethod + def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: + """ + Search for machines by IP address + + :param ip: The IP address to search for + :return: A sequence of Machines that have a network interface with a matching IP + :raises UnknownRecordError: If a `Machine` with the specified `ip` does not exist in the + repository + :raises RetrievalError: If an error occurred while attempting to retrieve the `Machine` + """ + + @abstractmethod + def reset(self): + """ + Removes all data from the repository + + :raises RemovalError: If an error occurred while attempting to remove all `Machines` from + the repository + """ diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index ee8beb13c..351cac981 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++
b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -3,10 +3,16 @@ from typing import Any, Dict, Mapping, Sequence from pymongo import MongoClient from common.credentials import Credentials -from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError -from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository +from monkey_island.cc.repository import ( + ICredentialsRepository, + RemovalError, + RetrievalError, + StorageError, +) from monkey_island.cc.server_utils.encryption import ILockableEncryptor +from .consts import MONGO_OBJECT_ID_KEY + class MongoCredentialsRepository(ICredentialsRepository): """ @@ -51,7 +57,7 @@ class MongoCredentialsRepository(ICredentialsRepository): collection_result = [] list_collection_result = list(collection.find({})) for encrypted_credentials in list_collection_result: - del encrypted_credentials["_id"] + del encrypted_credentials[MONGO_OBJECT_ID_KEY] plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials) collection_result.append(Credentials.from_mapping(plaintext_credentials)) diff --git a/monkey/monkey_island/cc/repository/mongo_machine_repository.py b/monkey/monkey_island/cc/repository/mongo_machine_repository.py new file mode 100644 index 000000000..4d1a36470 --- /dev/null +++ b/monkey/monkey_island/cc/repository/mongo_machine_repository.py @@ -0,0 +1,95 @@ +from ipaddress import IPv4Address +from threading import Lock +from typing import Any, MutableMapping, Sequence + +from pymongo import MongoClient + +from common.types import HardwareID +from monkey_island.cc.models import Machine, MachineID + +from . 
import IMachineRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError +from .consts import MONGO_OBJECT_ID_KEY + + +class MongoMachineRepository(IMachineRepository): + """A repository used to store and retrieve Machines in MongoDB""" + + def __init__(self, mongo_client: MongoClient): + self._machines_collection = mongo_client.monkey_island.machines + self._id_lock = Lock() + self._next_id = self._get_biggest_id() + + def _get_biggest_id(self) -> MachineID: + try: + return self._machines_collection.find().sort("id", -1).limit(1)[0]["id"] + except IndexError: + return 0 + + def get_new_id(self) -> MachineID: + with self._id_lock: + self._next_id += 1 + return self._next_id + + def upsert_machine(self, machine: Machine): + try: + result = self._machines_collection.replace_one( + {"id": machine.id}, machine.dict(simplify=True), upsert=True + ) + except Exception as err: + raise StorageError(f'Error updating machine with ID "{machine.id}": {err}') + + if result.matched_count != 0 and result.modified_count != 1: + raise StorageError( + f'Error updating machine with ID "{machine.id}": Expected to update 1 machine, ' + f"but {result.modified_count} were updated" + ) + + if result.matched_count == 0 and result.upserted_id is None: + raise StorageError( + f'Error inserting machine with ID "{machine.id}": Expected to insert 1 machine, ' + f"but no machines were inserted" + ) + + def get_machine_by_id(self, machine_id: MachineID) -> Machine: + return self._find_one("id", machine_id) + + def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine: + return self._find_one("hardware_id", hardware_id) + + def _find_one(self, key: str, search_value: Any) -> Machine: + try: + machine_dict = self._machines_collection.find_one({key: search_value}) + except Exception as err: + raise RetrievalError(f'Error retrieving machine with "{key} == {search_value}": {err}') + + if machine_dict is None: + raise UnknownRecordError(f'Unknown machine with "{key} == 
{search_value}"') + + return MongoMachineRepository._mongo_record_to_machine(machine_dict) + + def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: + ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$" + query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}} + + try: + cursor = self._machines_collection.find(query) + except Exception as err: + raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}') + + machines = list(map(MongoMachineRepository._mongo_record_to_machine, cursor)) + + if len(machines) == 0: + raise UnknownRecordError(f'No machines found with IP "{ip}"') + + return machines + + @staticmethod + def _mongo_record_to_machine(mongo_record: MutableMapping[str, Any]) -> Machine: + del mongo_record[MONGO_OBJECT_ID_KEY] + return Machine(**mongo_record) + + def reset(self): + try: + self._machines_collection.drop() + except Exception as err: + raise RemovalError(f"Error resetting the repository: {err}") diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_machine.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_machine.py index 50cdaa61d..49eeae9a7 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/models/test_machine.py +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_machine.py @@ -84,6 +84,15 @@ def test_construct_invalid_field__value_error(key, value): Machine(**invalid_type_dict) +@pytest.mark.parametrize("field", ["hardware_id", "operating_system"]) +def test_optional_fields(field): + none_field_dict = MACHINE_SIMPLE_DICT.copy() + none_field_dict[field] = None + + # Raises exception_on_failure + Machine(**none_field_dict) + + def test_construct__extra_fields_forbidden(): extra_field_dict = MACHINE_SIMPLE_DICT.copy() extra_field_dict["extra_field"] = 99 # red balloons @@ -112,6 +121,15 @@ def test_hardware_id_validate_on_set(): m.hardware_id = -50 +def test_hardware_id_default(): + missing_hardware_id_dict = MACHINE_OBJECT_DICT.copy() + del 
missing_hardware_id_dict["hardware_id"] + + m = Machine(**missing_hardware_id_dict) + + assert m.hardware_id is None + + def test_network_interfaces_set_valid_value(): m = Machine(**MACHINE_OBJECT_DICT) @@ -132,6 +150,15 @@ def test_network_interfaces_sequence_is_immutable(): assert not isinstance(m.network_interfaces, MutableSequence) +def test_network_interfaces_default(): + missing_network_interfaces_dict = MACHINE_OBJECT_DICT.copy() + del missing_network_interfaces_dict["network_interfaces"] + + m = Machine(**missing_network_interfaces_dict) + + assert len(m.network_interfaces) == 0 + + def test_operating_system_set_valid_value(): m = Machine(**MACHINE_OBJECT_DICT) @@ -146,6 +173,15 @@ def test_operating_system_set_invalid_value(): m.operating_system = "MacOS" +def test_operating_system_default_value(): + missing_operating_system_dict = MACHINE_OBJECT_DICT.copy() + del missing_operating_system_dict["operating_system"] + + m = Machine(**missing_operating_system_dict) + + assert m.operating_system is None + + def test_set_operating_system_version(): m = Machine(**MACHINE_OBJECT_DICT) @@ -153,8 +189,26 @@ def test_set_operating_system_version(): m.operating_system_version = "1234" +def test_operating_system_version_default_value(): + missing_operating_system_version_dict = MACHINE_OBJECT_DICT.copy() + del missing_operating_system_version_dict["operating_system_version"] + + m = Machine(**missing_operating_system_version_dict) + + assert m.operating_system_version == "" + + def test_set_hostname(): m = Machine(**MACHINE_OBJECT_DICT) # Raises exception_on_failure m.operating_system_version = "wopr" + + +def test_hostname_default_value(): + missing_hostname_dict = MACHINE_OBJECT_DICT.copy() + del missing_hostname_dict["hostname"] + + m = Machine(**missing_hostname_dict) + + assert m.hostname == "" diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py 
b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py new file mode 100644 index 000000000..90d0af1f2 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py @@ -0,0 +1,258 @@ +from ipaddress import IPv4Interface +from itertools import chain, repeat +from unittest.mock import MagicMock + +import mongomock +import pytest + +from common import OperatingSystem +from monkey_island.cc.models import Machine +from monkey_island.cc.repository import ( + IMachineRepository, + MongoMachineRepository, + RemovalError, + RetrievalError, + StorageError, + UnknownRecordError, +) + +MACHINES = ( + Machine( + id=1, + hardware_id=12345, + network_interfaces=[IPv4Interface("192.168.1.10/24")], + operating_system=OperatingSystem.LINUX, + operating_system_version="Ubuntu 22.04", + hostname="wopr", + ), + Machine( + id=2, + hardware_id=67890, + network_interfaces=[IPv4Interface("192.168.1.11/24"), IPv4Interface("192.168.1.12/24")], + operating_system=OperatingSystem.WINDOWS, + operating_system_version="eXtra Problems", + hostname="hal", + ), + Machine( + id=3, + hardware_id=112345, + network_interfaces=[IPv4Interface("192.168.1.13/24"), IPv4Interface("192.168.1.14/24")], + operating_system=OperatingSystem.WINDOWS, + operating_system_version="Vista", + hostname="smith", + ), + Machine( + id=4, + hardware_id=167890, + network_interfaces=[IPv4Interface("192.168.1.14/24")], + operating_system=OperatingSystem.LINUX, + operating_system_version="CentOS Linux 8", + hostname="skynet", + ), +) + + +@pytest.fixture +def mongo_client() -> mongomock.MongoClient: + client = mongomock.MongoClient() + client.monkey_island.machines.insert_many((m.dict(simplify=True) for m in MACHINES)) + return client + + +@pytest.fixture +def error_raising_mock_mongo_client() -> mongomock.MongoClient: + mongo_client = MagicMock(spec=mongomock.MongoClient) + mongo_client.monkey_island = MagicMock(spec=mongomock.Database) + 
mongo_client.monkey_island.machines = MagicMock(spec=mongomock.Collection) + + # The first call to find() must succeed + mongo_client.monkey_island.machines.find = MagicMock( + side_effect=chain([MagicMock()], repeat(Exception("some exception"))) + ) + mongo_client.monkey_island.machines.find_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.machines.insert_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.machines.replace_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.machines.drop = MagicMock(side_effect=Exception("some exception")) + + return mongo_client + + +@pytest.fixture +def error_raising_machine_repository(error_raising_mock_mongo_client) -> IMachineRepository: + return MongoMachineRepository(error_raising_mock_mongo_client) + + +@pytest.fixture +def machine_repository(mongo_client) -> IMachineRepository: + return MongoMachineRepository(mongo_client) + + +def test_get_new_id__unique_id(machine_repository): + new_machine_id = machine_repository.get_new_id() + + for m in MACHINES: + assert m.id != new_machine_id + + +def test_get_new_id__multiple_unique_ids(machine_repository): + id_1 = machine_repository.get_new_id() + id_2 = machine_repository.get_new_id() + + assert id_1 != id_2 + + +def test_get_new_id__new_id_for_empty_repo(machine_repository): + empty_machine_repository = MongoMachineRepository(mongomock.MongoClient()) + id_1 = empty_machine_repository.get_new_id() + id_2 = empty_machine_repository.get_new_id() + + assert id_1 != id_2 + + +def test_upsert_machine__update(machine_repository): + machine = machine_repository.get_machine_by_id(1) + + machine.operating_system = OperatingSystem.WINDOWS + machine.hostname = "viki" + machine.network_interfaces = [IPv4Interface("10.0.0.1/16")] + + machine_repository.upsert_machine(machine) + + assert machine_repository.get_machine_by_id(1) == machine + + +def 
test_upsert_machine__insert(machine_repository): + new_machine = Machine(id=99, hardware_id=8675309) + + machine_repository.upsert_machine(new_machine) + + assert machine_repository.get_machine_by_id(99) == new_machine + + +def test_upsert_machine__storage_error_exception(error_raising_machine_repository): + machine = MACHINES[0] + + with pytest.raises(StorageError): + error_raising_machine_repository.upsert_machine(machine) + + +def test_upsert_machine__storage_error_update_failed(error_raising_mock_mongo_client): + mock_result = MagicMock() + mock_result.matched_count = 1 + mock_result.modified_count = 0 + + error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock( + return_value=mock_result + ) + machine_repository = MongoMachineRepository(error_raising_mock_mongo_client) + + machine = MACHINES[0] + with pytest.raises(StorageError): + machine_repository.upsert_machine(machine) + + +def test_upsert_machine__storage_error_insert_failed(error_raising_mock_mongo_client): + mock_result = MagicMock() + mock_result.matched_count = 0 + mock_result.upserted_id = None + + error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock( + return_value=mock_result + ) + machine_repository = MongoMachineRepository(error_raising_mock_mongo_client) + + machine = MACHINES[0] + with pytest.raises(StorageError): + machine_repository.upsert_machine(machine) + + +def test_get_machine_by_id(machine_repository): + for i, expected_machine in enumerate(MACHINES, start=1): + assert machine_repository.get_machine_by_id(i) == expected_machine + + +def test_get_machine_by_id__not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.get_machine_by_id(9999) + + +def test_get_machine_by_id__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machine_by_id(1) + + +def test_get_machine_by_hardware_id(machine_repository): + for hardware_id, 
expected_machine in ((machine.hardware_id, machine) for machine in MACHINES): + assert machine_repository.get_machine_by_hardware_id(hardware_id) == expected_machine + + +def test_get_machine_by_hardware_id__not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.get_machine_by_hardware_id(9999888887777) + + +def test_get_machine_by_hardware_id__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machine_by_hardware_id(1) + + +def test_get_machine_by_ip(machine_repository): + expected_machine = MACHINES[0] + expected_machine_ip = expected_machine.network_interfaces[0].ip + + retrieved_machines = machine_repository.get_machines_by_ip(expected_machine_ip) + + assert len(retrieved_machines) == 1 + assert retrieved_machines[0] == expected_machine + + +def test_get_machine_by_ip__multiple_results(machine_repository): + search_ip = MACHINES[3].network_interfaces[0].ip + + retrieved_machines = machine_repository.get_machines_by_ip(search_ip) + + assert len(retrieved_machines) == 2 + assert MACHINES[2] in retrieved_machines + assert MACHINES[3] in retrieved_machines + + +def test_get_machine_by_ip__not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.get_machines_by_ip("1.1.1.1") + + +def test_get_machine_by_ip__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machines_by_ip("1.1.1.1") + + +def test_reset(machine_repository): + # Ensure the repository is not empty + preexisting_machine = machine_repository.get_machine_by_id(MACHINES[0].id) + assert isinstance(preexisting_machine, Machine) + + machine_repository.reset() + + with pytest.raises(UnknownRecordError): + machine_repository.get_machine_by_id(MACHINES[0].id) + + +def test_usable_after_reset(machine_repository): + machine_repository.reset() + + new_id = machine_repository.get_new_id() + 
new_machine = Machine(id=new_id) + machine_repository.upsert_machine(new_machine) + + assert new_machine == machine_repository.get_machine_by_id(new_machine.id) + + +def test_reset__removal_error(error_raising_machine_repository): + with pytest.raises(RemovalError): + error_raising_machine_repository.reset() diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 5ab45088a..32d3686b1 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -245,8 +245,11 @@ IConfigRepository.get_config_field ILogRepository.get_logs ILogRepository.save_log ILogRepository.delete_log -IMachineRepository.save_machine -IMachineRepository.get_machines +IMachineRepository.get_new_id +IMachineRepository.upsert_machine +IMachineRepository.get_machine_by_id +IMachineRepository.get_machine_by_hardware_id +IMachineRepository.get_machines_by_ip INetworkMapRepository.get_map INetworkMapRepository.save_netmap IReportRepository