diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 0eac3637c..1800eb737 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -1,12 +1,13 @@ +from copy import deepcopy from ipaddress import IPv4Interface from logging import getLogger from typing import Union from typing_extensions import TypeAlias -from common.agent_events import PingScanEvent, TCPScanEvent -from common.types import PortStatus -from monkey_island.cc.models import CommunicationType, Machine +from common.agent_events import AbstractAgentEvent, PingScanEvent, TCPScanEvent +from common.types import PortStatus, SocketAddress +from monkey_island.cc.models import CommunicationType, Machine, Node from monkey_island.cc.repository import ( IAgentRepository, IMachineRepository, @@ -56,11 +57,22 @@ class ScanEventHandler: try: target_machine = self._get_target_machine(event) + source_node = self._get_source_node(event) self._update_nodes(target_machine, event) + self._update_tcp_connections(source_node, target_machine, event) except (RetrievalError, StorageError, UnknownRecordError): logger.exception("Unable to process tcp scan data") + def _get_source_node(self, event: AbstractAgentEvent) -> Node: + machine = self._get_source_machine(event) + try: + return [ + node for node in self._node_repository.get_nodes() if node.machine_id == machine.id + ][0] + except KeyError: + raise UnknownRecordError(f"Source node for event {event} does not exist") + def _get_target_machine(self, event: ScanEvent) -> Machine: try: target_machines = self._machine_repository.get_machines_by_ip(event.target) @@ -85,6 +97,21 @@ class ScanEventHandler: src_machine.id, target_machine.id, CommunicationType.SCANNED ) + def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent): + node_connections = dict(deepcopy(src_node.tcp_connections)) + try: + machine_connections = set(node_connections[target_machine.id]) + except KeyError: + machine_connections = set() + open_ports = [port for port, status in event.ports.items() if status == PortStatus.OPEN] + for open_port in open_ports: + socket_address = SocketAddress(ip=event.target, port=open_port) + machine_connections.add(socket_address) + + node_connections[target_machine.id] = tuple(machine_connections) + src_node.tcp_connections = node_connections + self._node_repository.upsert_node(src_node) + def _get_source_machine(self, event: ScanEvent) -> Machine: agent = self._agent_repository.get_agent_by_id(event.source) return self._machine_repository.get_machine_by_id(agent.machine_id) diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py index ad1ced7fa..f6832b788 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py @@ -1,3 +1,4 @@ +from copy import deepcopy from ipaddress import IPv4Address, IPv4Interface from itertools import count from unittest.mock import MagicMock @@ -9,7 +10,7 @@ from common import OperatingSystem from common.agent_events import PingScanEvent, TCPScanEvent from common.types import PortStatus, SocketAddress from monkey_island.cc.agent_event_handlers import ScanEventHandler -from monkey_island.cc.models import Agent, CommunicationType, Machine +from monkey_island.cc.models import Agent, CommunicationType, Machine, Node from monkey_island.cc.repository import ( IAgentRepository, IMachineRepository, @@ -29,43 +30,74 @@ SOURCE_MACHINE = Machine( hardware_id=5, network_interfaces=[IPv4Interface("10.10.10.99/24")], ) + +TARGET_MACHINE_ID = 33 +TARGET_MACHINE_IP = "10.10.10.1" TARGET_MACHINE = Machine( - id=33, + id=TARGET_MACHINE_ID, hardware_id=9, - network_interfaces=[IPv4Interface("10.10.10.1/24")], + network_interfaces=[IPv4Interface(f"{TARGET_MACHINE_IP}/24")], +) + +SOURCE_NODE = Node( + machine_id=SOURCE_MACHINE.id, + connections=[], + tcp_connections={ + 44: (SocketAddress(ip="1.1.1.1", port=40), SocketAddress(ip="2.2.2.2", port=50)) + }, +) + +SOURCE_NODE_2 = Node( + machine_id=SOURCE_MACHINE.id, + connections=[], + tcp_connections={ + 44: (SocketAddress(ip="1.1.1.1", port=40), SocketAddress(ip="2.2.2.2", port=50)), + TARGET_MACHINE_ID: (SocketAddress(ip=TARGET_MACHINE_IP, port=22),), + }, +) + +EXPECTED_NODE = Node( + machine_id=SOURCE_MACHINE.id, + connections=[], + tcp_connections={ + 44: (SocketAddress(ip="1.1.1.1", port=40), SocketAddress(ip="2.2.2.2", port=50)), + TARGET_MACHINE_ID: ( + SocketAddress(ip=TARGET_MACHINE_IP, port=22), + SocketAddress(ip=TARGET_MACHINE_IP, port=80), + ), + }, ) PING_SCAN_EVENT = PingScanEvent( source=AGENT_ID, - target=IPv4Address("10.10.10.1"), + target=IPv4Address(TARGET_MACHINE_IP), response_received=True, os=OperatingSystem.LINUX, ) PING_SCAN_EVENT_NO_RESPONSE = PingScanEvent( source=AGENT_ID, - target=IPv4Address("10.10.10.1"), + target=IPv4Address(TARGET_MACHINE_IP), response_received=False, os=OperatingSystem.LINUX, ) PING_SCAN_EVENT_NO_OS = PingScanEvent( source=AGENT_ID, - target=IPv4Address("10.10.10.1"), + target=IPv4Address(TARGET_MACHINE_IP), response_received=True, os=None, ) - TCP_SCAN_EVENT = TCPScanEvent( source=AGENT_ID, - target=IPv4Address("10.10.10.1"), - ports={22: PortStatus.OPEN, 8080: PortStatus.CLOSED}, + target=IPv4Address(TARGET_MACHINE_IP), + ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED}, ) TCP_SCAN_EVENT_CLOSED = TCPScanEvent( source=AGENT_ID, - target=IPv4Address("10.10.10.1"), + target=IPv4Address(TARGET_MACHINE_IP), ports={145: PortStatus.CLOSED, 8080: PortStatus.CLOSED}, ) @@ -91,6 +123,8 @@ def machine_repository() -> IMachineRepository: @pytest.fixture def node_repository() -> INodeRepository: node_repository = MagicMock(spec=INodeRepository) + node_repository.get_nodes.return_value = [deepcopy(SOURCE_NODE)] + node_repository.upsert_node = MagicMock() node_repository.upsert_communication = MagicMock() return node_repository @@ -103,7 +137,7 @@ def scan_event_handler(agent_repository, machine_repository, node_repository): MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE} MACHINES_BY_IP = { IPv4Address("10.10.10.99"): [SOURCE_MACHINE], - IPv4Address("10.10.10.1"): [TARGET_MACHINE], + IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE], } @@ -186,6 +220,27 @@ def test_tcp_scan_event_target_machine_not_exists( machine_repository.upsert_machine.assert_called_with(expected_machine) +def test_handle_tcp_scan_event__tcp_connections( + scan_event_handler, machine_repository, node_repository +): + event = TCP_SCAN_EVENT + scan_event_handler._update_nodes = MagicMock() + scan_event_handler.handle_tcp_scan_event(event) + + node_repository.upsert_node.assert_called_with(EXPECTED_NODE) + + +def test_handle_tcp_scan_event__tcp_connections_upsert( + scan_event_handler, machine_repository, node_repository +): + event = TCP_SCAN_EVENT + node_repository.get_nodes.return_value = [deepcopy(SOURCE_NODE_2)] + scan_event_handler._update_nodes = MagicMock() + scan_event_handler.handle_tcp_scan_event(event) + + node_repository.upsert_node.assert_called_with(EXPECTED_NODE) + + @pytest.mark.parametrize( "event,handler", [(PING_SCAN_EVENT, HANDLE_PING_SCAN_METHOD), (TCP_SCAN_EVENT, HANDLE_TCP_SCAN_METHOD)],