Island: Refactor how Machine objects are managed by IMachineRepository

- Replace `create_machine()` with `get_new_id()`
- Replace `update_machine()` with `upsert_machine()`

Benefits:
    The repository doesn't store Machine objects that only have the ID
    populated (unless that is the caller's desire).

    Upsert instead of update allows the interface to be more permissive.
This commit is contained in:
Mike Salvatore 2022-08-31 10:14:57 -04:00
parent 81128a4842
commit ba7dab26d7
4 changed files with 63 additions and 56 deletions

View File

@ -10,21 +10,22 @@ class IMachineRepository(ABC):
"""A repository used to store and retrieve Machines""" """A repository used to store and retrieve Machines"""
@abstractmethod @abstractmethod
def create_machine(self) -> Machine: def get_new_id(self) -> MachineID:
""" """
Create a new `Machine` in the repository Generates a new, unique `MachineID`
:return: A new `Machine` with a unique ID :return: A new, unique `MachineID`
:raises StorageError: If a new `Machine` could not be created
""" """
@abstractmethod @abstractmethod
def update_machine(self, machine: Machine): def upsert_machine(self, machine: Machine):
""" """
Update an existing `Machine` in the repository Upsert (insert or update) a `Machine`
:param machine: An updated Machine object to store in the repository Insert the `Machine` if no `Machine` with a matching ID exists in the repository. If the
:raises UnknownRecordError: If the provided `Machine` does not exist in the repository `Machine` already exists, update it.
:param machine: The `Machine` to be inserted or updated
:raises StorageError: If an error occurred while attempting to store the `Machine` :raises StorageError: If an error occurred while attempting to store the `Machine`
""" """

View File

@ -25,36 +25,29 @@ class MongoMachineRepository(IMachineRepository):
except IndexError: except IndexError:
return 0 return 0
def create_machine(self) -> Machine: def get_new_id(self) -> MachineID:
try:
next_id = self._get_next_id()
new_machine = Machine(id=next_id)
self._machines_collection.insert_one(new_machine.dict(simplify=True))
return new_machine
except Exception as err:
raise StorageError(f"Error creating a new machine: {err}")
def _get_next_id(self) -> MachineID:
with self._id_lock: with self._id_lock:
self._next_id += 1 self._next_id += 1
return self._next_id return self._next_id
def update_machine(self, machine: Machine): def upsert_machine(self, machine: Machine):
try: try:
result = self._machines_collection.replace_one( result = self._machines_collection.replace_one(
{"id": machine.id}, machine.dict(simplify=True) {"id": machine.id}, machine.dict(simplify=True), upsert=True
) )
except Exception as err: except Exception as err:
raise StorageError(f'Error updating machine with ID "{machine.id}": {err}') raise StorageError(f'Error updating machine with ID "{machine.id}": {err}')
if result.matched_count == 0: if result.matched_count != 0 and result.modified_count != 1:
raise UnknownRecordError(f"Unknown machine: id == {machine.id}")
if result.modified_count != 1:
raise StorageError( raise StorageError(
f'Error updating machine with ID "{machine.id}": Expected to update 1 machines, ' f'Error updating machine with ID "{machine.id}": Expected to update 1 machine, '
f"but updated {result.modified_count}" f"but {result.modified_count} were updated"
)
if result.matched_count == 0 and result.upserted_id is None:
raise StorageError(
f'Error inserting machine with ID "{machine.id}": Expected to insert 1 machine, '
f"but no machines were inserted"
) )
def get_machine_by_id(self, machine_id: MachineID) -> Machine: def get_machine_by_id(self, machine_id: MachineID) -> Machine:

View File

@ -93,60 +93,56 @@ def machine_repository(mongo_client) -> IMachineRepository:
return MongoMachineRepository(mongo_client) return MongoMachineRepository(mongo_client)
def test_create_machine__unique_id(machine_repository): def test_get_new_id__unique_id(machine_repository):
new_machine = machine_repository.create_machine() new_machine_id = machine_repository.get_new_id()
for m in MACHINES: for m in MACHINES:
assert m.id != new_machine.id assert m.id != new_machine_id
def test_create_machine__multiple_unique_ids(machine_repository): def test_get_new_id__multiple_unique_ids(machine_repository):
new_machine_1 = machine_repository.create_machine() id_1 = machine_repository.get_new_id()
new_machine_2 = machine_repository.create_machine() id_2 = machine_repository.get_new_id()
assert new_machine_1.id != new_machine_2.id assert id_1 != id_2
def test_create_machine__new_id_for_empty_repo(machine_repository): def test_get_new_id__new_id_for_empty_repo(machine_repository):
empty_machine_repository = MongoMachineRepository(mongomock.MongoClient()) empty_machine_repository = MongoMachineRepository(mongomock.MongoClient())
new_machine_1 = empty_machine_repository.create_machine() id_1 = empty_machine_repository.get_new_id()
new_machine_2 = empty_machine_repository.create_machine() id_2 = empty_machine_repository.get_new_id()
assert new_machine_1.id != new_machine_2.id assert id_1 != id_2
def test_create_machine__storage_error(error_raising_machine_repository): def test_upsert_machine__update(machine_repository):
with pytest.raises(StorageError):
error_raising_machine_repository.create_machine()
def test_update_machine(machine_repository):
machine = machine_repository.get_machine_by_id(1) machine = machine_repository.get_machine_by_id(1)
machine.operating_system = OperatingSystem.WINDOWS machine.operating_system = OperatingSystem.WINDOWS
machine.hostname = "viki" machine.hostname = "viki"
machine.network_interfaces = [IPv4Interface("10.0.0.1/16")] machine.network_interfaces = [IPv4Interface("10.0.0.1/16")]
machine_repository.update_machine(machine) machine_repository.upsert_machine(machine)
assert machine_repository.get_machine_by_id(1) == machine assert machine_repository.get_machine_by_id(1) == machine
def test_update_machine__not_found(machine_repository): def test_upsert_machine__insert(machine_repository):
machine = Machine(id=99) new_machine = Machine(id=99, hardware_id=8675309)
with pytest.raises(UnknownRecordError): machine_repository.upsert_machine(new_machine)
machine_repository.update_machine(machine)
assert machine_repository.get_machine_by_id(99) == new_machine
def test_update_machine__storage_error_exception(error_raising_machine_repository): def test_upsert_machine__storage_error_exception(error_raising_machine_repository):
machine = MACHINES[0] machine = MACHINES[0]
with pytest.raises(StorageError): with pytest.raises(StorageError):
error_raising_machine_repository.update_machine(machine) error_raising_machine_repository.upsert_machine(machine)
def test_update_machine__storage_error_update_failed(error_raising_mock_mongo_client): def test_upsert_machine__storage_error_update_failed(error_raising_mock_mongo_client):
mock_result = MagicMock() mock_result = MagicMock()
mock_result.matched_count = 1 mock_result.matched_count = 1
mock_result.modified_count = 0 mock_result.modified_count = 0
@ -158,7 +154,22 @@ def test_update_machine__storage_error_update_failed(error_raising_mock_mongo_cl
machine = MACHINES[0] machine = MACHINES[0]
with pytest.raises(StorageError): with pytest.raises(StorageError):
machine_repository.update_machine(machine) machine_repository.upsert_machine(machine)
def test_upsert_machine__storage_error_insert_failed(error_raising_mock_mongo_client):
mock_result = MagicMock()
mock_result.matched_count = 0
mock_result.upserted_id = None
error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock(
return_value=mock_result
)
machine_repository = MongoMachineRepository(error_raising_mock_mongo_client)
machine = MACHINES[0]
with pytest.raises(StorageError):
machine_repository.upsert_machine(machine)
def test_get_machine_by_id(machine_repository): def test_get_machine_by_id(machine_repository):
@ -235,7 +246,9 @@ def test_reset(machine_repository):
def test_usable_after_reset(machine_repository): def test_usable_after_reset(machine_repository):
machine_repository.reset() machine_repository.reset()
new_machine = machine_repository.create_machine() new_id = machine_repository.get_new_id()
new_machine = Machine(id=new_id)
machine_repository.upsert_machine(new_machine)
assert new_machine == machine_repository.get_machine_by_id(new_machine.id) assert new_machine == machine_repository.get_machine_by_id(new_machine.id)

View File

@ -244,8 +244,8 @@ IConfigRepository.get_config_field
ILogRepository.get_logs ILogRepository.get_logs
ILogRepository.save_log ILogRepository.save_log
ILogRepository.delete_log ILogRepository.delete_log
IMachineRepository.create_machine IMachineRepository.get_new_id
IMachineRepository.update_machine IMachineRepository.upsert_machine
IMachineRepository.get_machine_by_id IMachineRepository.get_machine_by_id
IMachineRepository.get_machine_by_hardware_id IMachineRepository.get_machine_by_hardware_id
IMachineRepository.get_machines_by_ip IMachineRepository.get_machines_by_ip