Island: Improve tcp handler code and coverage

This commit is contained in:
vakarisz 2022-10-05 17:07:19 +03:00
parent 6c913895c5
commit 249950d602
2 changed files with 22 additions and 7 deletions

View File

@ -70,7 +70,7 @@ class ScanEventHandler:
return [ return [
node for node in self._node_repository.get_nodes() if node.machine_id == machine.id node for node in self._node_repository.get_nodes() if node.machine_id == machine.id
][0] ][0]
except KeyError: except IndexError:
raise UnknownRecordError(f"Source node for event {event} does not exist") raise UnknownRecordError(f"Source node for event {event} does not exist")
def _get_target_machine(self, event: ScanEvent) -> Machine: def _get_target_machine(self, event: ScanEvent) -> Machine:
@ -99,10 +99,7 @@ class ScanEventHandler:
def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent): def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent):
node_connections = dict(deepcopy(src_node.tcp_connections)) node_connections = dict(deepcopy(src_node.tcp_connections))
try: machine_connections = set(node_connections.get(target_machine.id, set()))
machine_connections = set(node_connections[target_machine.id])
except KeyError:
machine_connections = set()
open_ports = [port for port, status in event.ports.items() if status == PortStatus.OPEN] open_ports = [port for port, status in event.ports.items() if status == PortStatus.OPEN]
for open_port in open_ports: for open_port in open_ports:
socket_address = SocketAddress(ip=event.target, port=open_port) socket_address = SocketAddress(ip=event.target, port=open_port)

View File

@ -227,7 +227,10 @@ def test_handle_tcp_scan_event__tcp_connections(
scan_event_handler._update_nodes = MagicMock() scan_event_handler._update_nodes = MagicMock()
scan_event_handler.handle_tcp_scan_event(event) scan_event_handler.handle_tcp_scan_event(event)
node_repository.upsert_node.assert_called_with(EXPECTED_NODE) node_passed = node_repository.upsert_node.call_args[0][0]
assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set(
EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID]
)
def test_handle_tcp_scan_event__tcp_connections_upsert( def test_handle_tcp_scan_event__tcp_connections_upsert(
@ -238,7 +241,22 @@ def test_handle_tcp_scan_event__tcp_connections_upsert(
scan_event_handler._update_nodes = MagicMock() scan_event_handler._update_nodes = MagicMock()
scan_event_handler.handle_tcp_scan_event(event) scan_event_handler.handle_tcp_scan_event(event)
node_repository.upsert_node.assert_called_with(EXPECTED_NODE) node_passed = node_repository.upsert_node.call_args[0][0]
assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set(
EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID]
)
def test_handle_tcp_scan_event__no_source(
caplog, scan_event_handler, machine_repository, node_repository
):
event = TCP_SCAN_EVENT
node_repository.get_nodes.return_value = []
scan_event_handler._update_nodes = MagicMock()
scan_event_handler.handle_tcp_scan_event(event)
assert "ERROR" in caplog.text
assert f"Source node for event {event} does not exist" in caplog.text
@pytest.mark.parametrize( @pytest.mark.parametrize(