diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 3c43ddd92..b79fc2558 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -78,7 +78,12 @@ class ScanEventHandler: def _get_source_node(self, event: ScanEvent) -> Node: machine = self._get_source_machine(event) - return self._node_repository.get_node_by_machine_id(machine.id) + try: + node = self._node_repository.get_node_by_machine_id(machine.id) + except UnknownRecordError: + node = Node(machine_id=machine.id) + self._node_repository.upsert_node(node) + return node def _get_source_machine(self, event: ScanEvent) -> Machine: agent = self._agent_repository.get_agent_by_id(event.source) diff --git a/monkey/monkey_island/cc/models/node.py b/monkey/monkey_island/cc/models/node.py index 95404f7d1..65e6e4b6b 100644 --- a/monkey/monkey_island/cc/models/node.py +++ b/monkey/monkey_island/cc/models/node.py @@ -24,7 +24,7 @@ class Node(MutableInfectionMonkeyBaseModel): machine_id: MachineID = Field(..., allow_mutation=False) """The MachineID of the node (source)""" - connections: NodeConnections + connections: NodeConnections = {} """All outbound connections from this node to other machines""" tcp_connections: TCPConnections = {} diff --git a/monkey/monkey_island/cc/repository/i_node_repository.py b/monkey/monkey_island/cc/repository/i_node_repository.py index 157c9274c..c0c6d3ac6 100644 --- a/monkey/monkey_island/cc/repository/i_node_repository.py +++ b/monkey/monkey_island/cc/repository/i_node_repository.py @@ -44,6 +44,14 @@ class INodeRepository(ABC): :raises RetrievalError: If an error occurs while attempting to retrieve the nodes """ + @abstractmethod + def upsert_node(self, node: Node): + """ + Update or insert Node model into the database + :param node: Node model to be added to the repository + :raises StorageError: If something went wrong when upserting the Node + """ + @abstractmethod def get_node_by_machine_id(self, machine_id: MachineID) -> Node: """ diff --git a/monkey/monkey_island/cc/repository/mongo_node_repository.py b/monkey/monkey_island/cc/repository/mongo_node_repository.py index befc81632..7dee44c88 100644 --- a/monkey/monkey_island/cc/repository/mongo_node_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_node_repository.py @@ -30,7 +30,7 @@ class MongoNodeRepository(INodeRepository): except Exception as err: raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}") - self._upsert_node(updated_node) + self.upsert_node(updated_node) @staticmethod def _add_connection_to_node( @@ -57,9 +57,9 @@ class MongoNodeRepository(INodeRepository): node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections}) else: node.tcp_connections[target] = connections - self._upsert_node(node) + self.upsert_node(node) - def _upsert_node(self, node: Node): + def upsert_node(self, node: Node): try: result = self._nodes_collection.replace_one( {SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py index 55b8f1bce..d997865a0 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py @@ -22,11 +22,13 @@ from monkey_island.cc.repository import ( SEED_ID = 99 AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2") -MACHINE_ID = 11 +SOURCE_MACHINE_ID = 11 CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000") -AGENT = Agent(id=AGENT_ID, machine_id=MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER) +AGENT = Agent( + id=AGENT_ID, machine_id=SOURCE_MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER +) SOURCE_MACHINE = Machine( - id=MACHINE_ID, + id=SOURCE_MACHINE_ID, hardware_id=5, network_interfaces=[IPv4Interface("10.10.10.99/24")], ) @@ -125,7 +127,7 @@ def scan_event_handler(agent_repository, machine_repository, node_repository): return ScanEventHandler(agent_repository, machine_repository, node_repository) -MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE} +MACHINES_BY_ID = {SOURCE_MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE} MACHINES_BY_IP = { IPv4Address("10.10.10.99"): [SOURCE_MACHINE], IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE], @@ -230,14 +232,14 @@ def test_handle_tcp_scan_event__ports_found( scan_event_handler.handle_tcp_scan_event(event) call_args = node_repository.upsert_tcp_connections.call_args[0] - assert call_args[0] == MACHINE_ID + assert call_args[0] == SOURCE_MACHINE_ID assert TARGET_MACHINE_ID in call_args[1] open_socket_addresses = call_args[1][TARGET_MACHINE_ID] assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID]) assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID]) -def test_handle_tcp_scan_event__no_source( +def test_handle_tcp_scan_event__no_source_node( caplog, scan_event_handler, machine_repository, node_repository ): event = TCP_SCAN_EVENT @@ -245,8 +247,11 @@ def test_handle_tcp_scan_event__no_source( scan_event_handler._update_nodes = MagicMock() scan_event_handler.handle_tcp_scan_event(event) - assert "ERROR" in caplog.text - assert "no source" in caplog.text + expected_node = Node(machine_id=SOURCE_MACHINE_ID) + node_called = node_repository.upsert_node.call_args[0][0] + assert expected_node.machine_id == node_called.machine_id + assert expected_node.connections == node_called.connections + assert expected_node.tcp_connections == node_called.tcp_connections @pytest.mark.parametrize(