From 3bc2e4876fac15f34ea86b7237e8248863e95fff Mon Sep 17 00:00:00 2001 From: vakarisz Date: Thu, 6 Oct 2022 14:45:56 +0300 Subject: [PATCH] Island: Handle missing node in add_tcp_connections --- .../monkey_island/cc/repository/mongo_node_repository.py | 3 +++ .../cc/repository/test_mongo_node_repository.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/monkey/monkey_island/cc/repository/mongo_node_repository.py b/monkey/monkey_island/cc/repository/mongo_node_repository.py index 599bddf61..b417a7554 100644 --- a/monkey/monkey_island/cc/repository/mongo_node_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_node_repository.py @@ -51,6 +51,9 @@ class MongoNodeRepository(INodeRepository): def add_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections): node = self._get_node_by_id(machine_id) + if node is None: + node = Node(machine_id=machine_id, connections={}) + for target, connections in tcp_connections.items(): if target in node.tcp_connections: node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections}) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py index d4b16cd46..3b2a0d26f 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py @@ -232,3 +232,10 @@ def test_upsert_tcp_connections__port_already_present(node_repository): modified_node = [node for node in nodes if node.machine_id == 4][0] assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS) assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS) + + +def test_upsert_tcp_connections__node_missing(node_repository): + node_repository.add_tcp_connections(999, TCP_CONNECTION_PORT_80) + nodes = node_repository.get_nodes() + modified_node = [node for node in nodes if node.machine_id == 999][0] + assert set(modified_node.tcp_connections) == set(TCP_CONNECTION_PORT_80)