From be4ecccdcdef0104d2aa5ad6a459b4bf604df53f Mon Sep 17 00:00:00 2001 From: vakarisz Date: Fri, 7 Oct 2022 10:05:06 +0300 Subject: [PATCH] Island: Refactor get_node_by_id to raise UnknownRecordError --- .../scan_event_handler.py | 7 +---- .../cc/repository/i_node_repository.py | 9 ++++++ .../cc/repository/mongo_node_repository.py | 30 +++++++++---------- .../test_scan_event_handler.py | 5 ++-- .../repository/test_mongo_node_repository.py | 10 +++++++ 5 files changed, 38 insertions(+), 23 deletions(-) diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 8c769a470..f3482b891 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -65,12 +65,7 @@ class ScanEventHandler: def _get_source_node(self, event: AbstractAgentEvent) -> Node: machine = self._get_source_machine(event) - try: - return [ - node for node in self._node_repository.get_nodes() if node.machine_id == machine.id - ][0] - except IndexError: - raise UnknownRecordError(f"Source node for event {event} does not exist") + return self._node_repository.get_node_by_machine_id(machine.id) def _get_target_machine(self, event: ScanEvent) -> Machine: try: diff --git a/monkey/monkey_island/cc/repository/i_node_repository.py b/monkey/monkey_island/cc/repository/i_node_repository.py index 181cd185e..157c9274c 100644 --- a/monkey/monkey_island/cc/repository/i_node_repository.py +++ b/monkey/monkey_island/cc/repository/i_node_repository.py @@ -44,6 +44,15 @@ class INodeRepository(ABC): :raises RetrievalError: If an error occurs while attempting to retrieve the nodes """ + @abstractmethod + def get_node_by_machine_id(self, machine_id: MachineID) -> Node: + """ + Fetches network Node from the database based on Machine id + :param machine_id: ID of a Machine that Node represents + :return: network Node that represents the Machine + :raises UnknownRecordError: If the Node does not exist + """ + @abstractmethod def reset(self): """ diff --git a/monkey/monkey_island/cc/repository/mongo_node_repository.py b/monkey/monkey_island/cc/repository/mongo_node_repository.py index b406b8fad..befc81632 100644 --- a/monkey/monkey_island/cc/repository/mongo_node_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_node_repository.py @@ -1,12 +1,12 @@ from copy import deepcopy -from typing import Optional, Sequence +from typing import Sequence from pymongo import MongoClient from monkey_island.cc.models import CommunicationType, MachineID, Node from ..models.node import TCPConnections -from . import INodeRepository, RemovalError, RetrievalError, StorageError +from . import INodeRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError from .consts import MONGO_OBJECT_ID_KEY UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node" @@ -21,16 +21,14 @@ class MongoNodeRepository(INodeRepository): self, src: MachineID, dst: MachineID, communication_type: CommunicationType ): try: - node = self._get_node_by_id(src) - except Exception as err: - raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}") - - if node is None: - updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))}) - else: + node = self.get_node_by_machine_id(src) updated_node = MongoNodeRepository._add_connection_to_node( node, dst, communication_type ) + except UnknownRecordError: + updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))}) + except Exception as err: + raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}") self._upsert_node(updated_node) @@ -49,9 +47,9 @@ class MongoNodeRepository(INodeRepository): return new_node def upsert_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections): - node = self._get_node_by_id(machine_id) - - if node is None: + try: + node = self.get_node_by_machine_id(machine_id) + except UnknownRecordError: node = Node(machine_id=machine_id, connections={}) for target, connections in tcp_connections.items(): @@ -75,11 +73,13 @@ class MongoNodeRepository(INodeRepository): f"node, but no nodes were inserted" ) - def _get_node_by_id(self, node_id: MachineID) -> Optional[Node]: + def get_node_by_machine_id(self, machine_id: MachineID) -> Node: node_dict = self._nodes_collection.find_one( - {SRC_FIELD_NAME: node_id}, {MONGO_OBJECT_ID_KEY: False} + {SRC_FIELD_NAME: machine_id}, {MONGO_OBJECT_ID_KEY: False} ) - return Node(**node_dict) if node_dict else None + if not node_dict: + raise UnknownRecordError(f"Node with machine ID {machine_id}") + return Node(**node_dict) def get_nodes(self) -> Sequence[Node]: try: diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py index e00896e3f..673f8293c 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py @@ -221,6 +221,7 @@ def test_handle_tcp_scan_event__ports_found( ): event = TCP_SCAN_EVENT scan_event_handler._update_nodes = MagicMock() + node_repository.get_node_by_machine_id.return_value = SOURCE_NODE scan_event_handler.handle_tcp_scan_event(event) call_args = node_repository.upsert_tcp_connections.call_args[0] @@ -235,12 +236,12 @@ def test_handle_tcp_scan_event__no_source( caplog, scan_event_handler, machine_repository, node_repository ): event = TCP_SCAN_EVENT - node_repository.get_nodes.return_value = [] + node_repository.get_node_by_machine_id = MagicMock(side_effect=UnknownRecordError("no source")) scan_event_handler._update_nodes = MagicMock() scan_event_handler.handle_tcp_scan_event(event) assert "ERROR" in caplog.text - assert f"Source node for event {event} does not exist" in caplog.text + assert "no source" in caplog.text @pytest.mark.parametrize( diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py index fa95043d6..bf4968406 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py @@ -11,6 +11,7 @@ from monkey_island.cc.repository import ( RemovalError, RetrievalError, StorageError, + UnknownRecordError, ) TARGET_MACHINE_IP = "2.2.2.2" @@ -239,3 +240,12 @@ def test_upsert_tcp_connections__node_missing(node_repository): nodes = node_repository.get_nodes() modified_node = [node for node in nodes if node.machine_id == 999][0] assert set(modified_node.tcp_connections) == set(TCP_CONNECTION_PORT_80) + + +def test_get_node_by_machine_id(node_repository): + assert node_repository.get_node_by_machine_id(1) == NODES[0] + + +def test_get_node_by_machine_id__no_node(node_repository): + with pytest.raises(UnknownRecordError): + node_repository.get_node_by_machine_id(999)