From 249950d602b28bfc9874dbf7d5329dad2fdc789b Mon Sep 17 00:00:00 2001 From: vakarisz Date: Wed, 5 Oct 2022 17:07:19 +0300 Subject: [PATCH] Island: Improve tcp handler code and coverage --- .../scan_event_handler.py | 7 ++---- .../test_scan_event_handler.py | 22 +++++++++++++++++-- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 1800eb737..6f3a7030f 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -70,7 +70,7 @@ class ScanEventHandler: return [ node for node in self._node_repository.get_nodes() if node.machine_id == machine.id ][0] - except KeyError: + except IndexError: raise UnknownRecordError(f"Source node for event {event} does not exist") def _get_target_machine(self, event: ScanEvent) -> Machine: @@ -99,10 +99,7 @@ class ScanEventHandler: def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent): node_connections = dict(deepcopy(src_node.tcp_connections)) - try: - machine_connections = set(node_connections[target_machine.id]) - except KeyError: - machine_connections = set() + machine_connections = set(node_connections.get(target_machine.id, set())) open_ports = [port for port, status in event.ports.items() if status == PortStatus.OPEN] for open_port in open_ports: socket_address = SocketAddress(ip=event.target, port=open_port) diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py index f6832b788..160d00ae1 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py @@ -227,7 +227,10 @@ def test_handle_tcp_scan_event__tcp_connections( scan_event_handler._update_nodes = MagicMock() scan_event_handler.handle_tcp_scan_event(event) - node_repository.upsert_node.assert_called_with(EXPECTED_NODE) + node_passed = node_repository.upsert_node.call_args[0][0] + assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set( + EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID] + ) def test_handle_tcp_scan_event__tcp_connections_upsert( @@ -238,7 +241,22 @@ def test_handle_tcp_scan_event__tcp_connections_upsert( scan_event_handler._update_nodes = MagicMock() scan_event_handler.handle_tcp_scan_event(event) - node_repository.upsert_node.assert_called_with(EXPECTED_NODE) + node_passed = node_repository.upsert_node.call_args[0][0] + assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set( + EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID] + ) + + +def test_handle_tcp_scan_event__no_source( + caplog, scan_event_handler, machine_repository, node_repository +): + event = TCP_SCAN_EVENT + node_repository.get_nodes.return_value = [] + scan_event_handler._update_nodes = MagicMock() + + scan_event_handler.handle_tcp_scan_event(event) + assert "ERROR" in caplog.text + assert f"Source node for event {event} does not exist" in caplog.text @pytest.mark.parametrize(