diff --git a/monkey/monkey_island/cc/repository/mongo_node_repository.py b/monkey/monkey_island/cc/repository/mongo_node_repository.py index 447bfe8bb..8543d5e21 100644 --- a/monkey/monkey_island/cc/repository/mongo_node_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_node_repository.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Sequence +from typing import Optional, Sequence from pymongo import MongoClient @@ -20,21 +20,18 @@ class MongoNodeRepository(INodeRepository): self, src: MachineID, dst: MachineID, communication_type: CommunicationType ): try: - node_dict = self._nodes_collection.find_one( - {SRC_FIELD_NAME: src}, {MONGO_OBJECT_ID_KEY: False} - ) + node = self._get_node_by_id(src) except Exception as err: raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}") - if node_dict is None: + if node is None: updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))}) else: - node = Node(**node_dict) updated_node = MongoNodeRepository._add_connection_to_node( node, dst, communication_type ) - self.upsert_node(updated_node) + self._upsert_node(updated_node) @staticmethod def _add_connection_to_node( @@ -70,6 +67,12 @@ class MongoNodeRepository(INodeRepository): f"node, but no nodes were inserted" ) + def _get_node_by_id(self, node_id: MachineID) -> Optional[Node]: + node_dict = self._nodes_collection.find_one( + {SRC_FIELD_NAME: node_id}, {MONGO_OBJECT_ID_KEY: False} + ) + return Node(**node_dict) if node_dict else None + def get_nodes(self) -> Sequence[Node]: try: cursor = self._nodes_collection.find({}, {MONGO_OBJECT_ID_KEY: False})