diff --git a/monkey/monkey_island/cc/repository/i_machine_repository.py b/monkey/monkey_island/cc/repository/i_machine_repository.py index 568564009..d60346eb2 100644 --- a/monkey/monkey_island/cc/repository/i_machine_repository.py +++ b/monkey/monkey_island/cc/repository/i_machine_repository.py @@ -10,21 +10,22 @@ class IMachineRepository(ABC): """A repository used to store and retrieve Machines""" @abstractmethod - def create_machine(self) -> Machine: + def get_new_id(self) -> MachineID: """ - Create a new `Machine` in the repository + Generates a new, unique `MachineID` - :return: A new `Machine` with a unique ID - :raises StorageError: If a new `Machine` could not be created + :return: A new, unique `MachineID` """ @abstractmethod - def update_machine(self, machine: Machine): + def upsert_machine(self, machine: Machine): """ - Update an existing `Machine` in the repository + Upsert (insert or update) a `Machine` - :param machine: An updated Machine object to store in the repository - :raises UnknownRecordError: If the provided `Machine` does not exist in the repository + Insert the `Machine` if no `Machine` with a matching ID exists in the repository. If the + `Machine` already exists, update it. + + :param machine: The `Machine` to be inserted or updated :raises StorageError: If an error occurred while attempting to store the `Machine` """ diff --git a/monkey/monkey_island/cc/repository/mongo_machine_repository.py b/monkey/monkey_island/cc/repository/mongo_machine_repository.py index cea459c25..4d1a36470 100644 --- a/monkey/monkey_island/cc/repository/mongo_machine_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_machine_repository.py @@ -25,36 +25,29 @@ class MongoMachineRepository(IMachineRepository): except IndexError: return 0 - def create_machine(self) -> Machine: - try: - next_id = self._get_next_id() - new_machine = Machine(id=next_id) - self._machines_collection.insert_one(new_machine.dict(simplify=True)) - - return new_machine - except Exception as err: - raise StorageError(f"Error creating a new machine: {err}") - - def _get_next_id(self) -> MachineID: + def get_new_id(self) -> MachineID: with self._id_lock: self._next_id += 1 return self._next_id - def update_machine(self, machine: Machine): + def upsert_machine(self, machine: Machine): try: result = self._machines_collection.replace_one( - {"id": machine.id}, machine.dict(simplify=True) + {"id": machine.id}, machine.dict(simplify=True), upsert=True ) except Exception as err: raise StorageError(f'Error updating machine with ID "{machine.id}": {err}') - if result.matched_count == 0: - raise UnknownRecordError(f"Unknown machine: id == {machine.id}") - - if result.modified_count != 1: + if result.matched_count != 0 and result.modified_count != 1: raise StorageError( - f'Error updating machine with ID "{machine.id}": Expected to update 1 machines, ' - f"but updated {result.modified_count}" + f'Error updating machine with ID "{machine.id}": Expected to update 1 machine, ' + f"but {result.modified_count} were updated" + ) + + if result.matched_count == 0 and result.upserted_id is None: + raise StorageError( + f'Error inserting machine with ID "{machine.id}": Expected to insert 1 machine, ' + f"but no machines were inserted" ) def get_machine_by_id(self, machine_id: MachineID) -> Machine: diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py index 7851f8e20..90d0af1f2 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py @@ -93,60 +93,56 @@ def machine_repository(mongo_client) -> IMachineRepository: return MongoMachineRepository(mongo_client) -def test_create_machine__unique_id(machine_repository): - new_machine = machine_repository.create_machine() +def test_get_new_id__unique_id(machine_repository): + new_machine_id = machine_repository.get_new_id() for m in MACHINES: - assert m.id != new_machine.id + assert m.id != new_machine_id -def test_create_machine__multiple_unique_ids(machine_repository): - new_machine_1 = machine_repository.create_machine() - new_machine_2 = machine_repository.create_machine() +def test_get_new_id__multiple_unique_ids(machine_repository): + id_1 = machine_repository.get_new_id() + id_2 = machine_repository.get_new_id() - assert new_machine_1.id != new_machine_2.id + assert id_1 != id_2 -def test_create_machine__new_id_for_empty_repo(machine_repository): +def test_get_new_id__new_id_for_empty_repo(machine_repository): empty_machine_repository = MongoMachineRepository(mongomock.MongoClient()) - new_machine_1 = empty_machine_repository.create_machine() - new_machine_2 = empty_machine_repository.create_machine() + id_1 = empty_machine_repository.get_new_id() + id_2 = empty_machine_repository.get_new_id() - assert new_machine_1.id != new_machine_2.id + assert id_1 != id_2 -def test_create_machine__storage_error(error_raising_machine_repository): - with pytest.raises(StorageError): - error_raising_machine_repository.create_machine() - - -def test_update_machine(machine_repository): +def test_upsert_machine__update(machine_repository): machine = machine_repository.get_machine_by_id(1) machine.operating_system = OperatingSystem.WINDOWS machine.hostname = "viki" machine.network_interfaces = [IPv4Interface("10.0.0.1/16")] - machine_repository.update_machine(machine) + machine_repository.upsert_machine(machine) assert machine_repository.get_machine_by_id(1) == machine -def test_update_machine__not_found(machine_repository): - machine = Machine(id=99) +def test_upsert_machine__insert(machine_repository): + new_machine = Machine(id=99, hardware_id=8675309) - with pytest.raises(UnknownRecordError): - machine_repository.update_machine(machine) + machine_repository.upsert_machine(new_machine) + + assert machine_repository.get_machine_by_id(99) == new_machine -def test_update_machine__storage_error_exception(error_raising_machine_repository): +def test_upsert_machine__storage_error_exception(error_raising_machine_repository): machine = MACHINES[0] with pytest.raises(StorageError): - error_raising_machine_repository.update_machine(machine) + error_raising_machine_repository.upsert_machine(machine) -def test_update_machine__storage_error_update_failed(error_raising_mock_mongo_client): +def test_upsert_machine__storage_error_update_failed(error_raising_mock_mongo_client): mock_result = MagicMock() mock_result.matched_count = 1 mock_result.modified_count = 0 @@ -158,7 +154,22 @@ def test_update_machine__storage_error_update_failed(error_raising_mock_mongo_cl machine = MACHINES[0] with pytest.raises(StorageError): - machine_repository.update_machine(machine) + machine_repository.upsert_machine(machine) + + +def test_upsert_machine__storage_error_insert_failed(error_raising_mock_mongo_client): + mock_result = MagicMock() + mock_result.matched_count = 0 + mock_result.upserted_id = None + + error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock( + return_value=mock_result + ) + machine_repository = MongoMachineRepository(error_raising_mock_mongo_client) + + machine = MACHINES[0] + with pytest.raises(StorageError): + machine_repository.upsert_machine(machine) def test_get_machine_by_id(machine_repository): @@ -235,7 +246,9 @@ def test_reset(machine_repository): def test_usable_after_reset(machine_repository): machine_repository.reset() - new_machine = machine_repository.create_machine() + new_id = machine_repository.get_new_id() + new_machine = Machine(id=new_id) + machine_repository.upsert_machine(new_machine) assert new_machine == machine_repository.get_machine_by_id(new_machine.id) diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 9da5fee84..688d2f123 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -244,8 +244,8 @@ IConfigRepository.get_config_field ILogRepository.get_logs ILogRepository.save_log ILogRepository.delete_log -IMachineRepository.create_machine -IMachineRepository.update_machine +IMachineRepository.get_new_id +IMachineRepository.upsert_machine IMachineRepository.get_machine_by_id IMachineRepository.get_machine_by_hardware_id IMachineRepository.get_machines_by_ip