diff --git a/monkey/monkey_island/cc/repository/i_machine_repository.py b/monkey/monkey_island/cc/repository/i_machine_repository.py index 85e8e71f3..7cea0bb02 100644 --- a/monkey/monkey_island/cc/repository/i_machine_repository.py +++ b/monkey/monkey_island/cc/repository/i_machine_repository.py @@ -53,6 +53,15 @@ class IMachineRepository(ABC): :raises RetrievalError: If an error occurs while attempting to retrieve the `Machine` """ + @abstractmethod + def get_machines(self) -> Sequence[Machine]: + """ + Get all machines in the repository + + :return: A sequence of all stored `Machine`s + :raises RetrievalError: If an error occurs while attempting to retrieve the `Machine`s + """ + @abstractmethod def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: """ diff --git a/monkey/monkey_island/cc/repository/mongo_machine_repository.py b/monkey/monkey_island/cc/repository/mongo_machine_repository.py index fab038694..75cbf2d46 100644 --- a/monkey/monkey_island/cc/repository/mongo_machine_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_machine_repository.py @@ -69,6 +69,14 @@ class MongoMachineRepository(IMachineRepository): return Machine(**machine_dict) + def get_machines(self) -> Sequence[Machine]: + try: + cursor = self._machines_collection.find({}, {MONGO_OBJECT_ID_KEY: False}) + except Exception as err: + raise RetrievalError(f"Error retrieving machines: {err}") + + return list(map(lambda m: Machine(**m), cursor)) + def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$" query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}} diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py index 90d0af1f2..bbd35283b 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py @@ -93,6 +93,11 @@ def machine_repository(mongo_client) -> IMachineRepository: return MongoMachineRepository(mongo_client) +@pytest.fixture +def empty_machine_repository() -> IMachineRepository: + return MongoMachineRepository(mongomock.MongoClient()) + + def test_get_new_id__unique_id(machine_repository): new_machine_id = machine_repository.get_new_id() @@ -107,8 +112,7 @@ def test_get_new_id__multiple_unique_ids(machine_repository): assert id_1 != id_2 -def test_get_new_id__new_id_for_empty_repo(machine_repository): - empty_machine_repository = MongoMachineRepository(mongomock.MongoClient()) +def test_get_new_id__new_id_for_empty_repo(empty_machine_repository): id_1 = empty_machine_repository.get_new_id() id_2 = empty_machine_repository.get_new_id() @@ -202,7 +206,7 @@ def test_get_machine_by_hardware_id__retrieval_error(error_raising_machine_repos error_raising_machine_repository.get_machine_by_hardware_id(1) -def test_get_machine_by_ip(machine_repository): +def test_get_machines_by_ip(machine_repository): expected_machine = MACHINES[0] expected_machine_ip = expected_machine.network_interfaces[0].ip @@ -212,7 +216,7 @@ def test_get_machine_by_ip(machine_repository): assert retrieved_machines[0] == expected_machine -def test_get_machine_by_ip__multiple_results(machine_repository): +def test_get_machines_by_ip__multiple_results(machine_repository): search_ip = MACHINES[3].network_interfaces[0].ip retrieved_machines = machine_repository.get_machines_by_ip(search_ip) @@ -222,16 +226,35 @@ def test_get_machine_by_ip__multiple_results(machine_repository): assert MACHINES[3] in retrieved_machines -def test_get_machine_by_ip__not_found(machine_repository): +def test_get_machines_by_ip__not_found(machine_repository): with pytest.raises(UnknownRecordError): machine_repository.get_machines_by_ip("1.1.1.1") -def test_get_machine_by_ip__retrieval_error(error_raising_machine_repository): +def test_get_machines_by_ip__retrieval_error(error_raising_machine_repository): with pytest.raises(RetrievalError): error_raising_machine_repository.get_machines_by_ip("1.1.1.1") +def test_get_machines(machine_repository): + retrieved_machines = machine_repository.get_machines() + + assert len(retrieved_machines) == len(MACHINES) + for machine in MACHINES: + assert machine in retrieved_machines + + +def test_get_machines__empty_repository(empty_machine_repository): + retrieved_machines = empty_machine_repository.get_machines() + + assert len(retrieved_machines) == 0 + + +def test_get_machines__retrieval_error(error_raising_machine_repository): + with pytest.raises(RetrievalError): + error_raising_machine_repository.get_machines() + + def test_reset(machine_repository): # Ensure the repository is not empty preexisting_machine = machine_repository.get_machine_by_id(MACHINES[0].id)